package gentest_test import ( "fmt" "go/ast" "go/parser" "go/token" "slices" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // TestCustomQueriesSynced makes sure the manual custom queries in modelqueries.go // are synced with the autogenerated queries.sql.go. This should probably be // autogenerated, but it's not atm and this is easy to throw in to elevate a better // error message. // // If this breaks, and is hard to fix, you can t.Skip() it. It is not a critical // test. Ping @Emyrk to fix it again. func TestCustomQueriesSyncedRowScan(t *testing.T) { t.Parallel() funcsToTrack := map[string]string{ "GetTemplatesWithFilter": "GetAuthorizedTemplates", "GetWorkspaces": "GetAuthorizedWorkspaces", "GetUsers": "GetAuthorizedUsers", } // Scan custom var custom []string for _, fn := range funcsToTrack { custom = append(custom, fn) } customFns := parseFile(t, "../modelqueries.go", func(name string) bool { return slices.Contains(custom, name) }) generatedFns := parseFile(t, "../queries.sql.go", func(name string) bool { _, ok := funcsToTrack[name] return ok }) merged := customFns for k, v := range generatedFns { merged[k] = v } for a, b := range funcsToTrack { a, b := a, b if !compareFns(t, a, b, merged[a], merged[b]) { //nolint:revive defer func() { // Run this at the end so the suggested fix is the last thing printed. t.Errorf("The functions %q and %q need to have identical 'rows.Scan()' "+ "and 'db.QueryContext()' arguments in their function bodies. "+ "Make sure to copy the function body from the autogenerated %q body. "+ "Specifically the parameters for 'rows.Scan()' and 'db.QueryContext()'.", a, b, a) }() } } } type parsedFunc struct { RowScanArgs []ast.Expr QueryArgs []ast.Expr } func parseFile(t *testing.T, filename string, trackFunc func(name string) bool) map[string]*parsedFunc { fset := token.NewFileSet() f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution) require.NoErrorf(t, err, "failed to parse file %q", filename) parsed := make(map[string]*parsedFunc) for _, decl := range f.Decls { if fn, ok := decl.(*ast.FuncDecl); ok { if trackFunc(fn.Name.Name) { parsed[fn.Name.String()] = &parsedFunc{ RowScanArgs: pullRowScanArgs(fn), QueryArgs: pullQueryArgs(fn), } } } } return parsed } func compareFns(t *testing.T, aName, bName string, a, b *parsedFunc) bool { if a == nil { t.Errorf("The function %q is missing", aName) return false } if b == nil { t.Errorf("The function %q is missing", bName) return false } r := compareArgs(t, "rows.Scan() arguments", aName, bName, a.RowScanArgs, b.RowScanArgs) if len(a.QueryArgs) > 2 && len(b.QueryArgs) > 2 { // This is because the actual query param name is different. One uses the // const, the other uses a variable that is a mutation of the original query. a.QueryArgs[1] = b.QueryArgs[1] } q := compareArgs(t, "db.QueryContext() arguments", aName, bName, a.QueryArgs, b.QueryArgs) return r && q } func compareArgs(t *testing.T, argType string, aName, bName string, a, b []ast.Expr) bool { return assert.Equal(t, argList(t, a), argList(t, b), "mismatched %s for %s and %s", argType, aName, bName) } func argList(t *testing.T, args []ast.Expr) []string { defer func() { if r := recover(); r != nil { t.Errorf("Recovered in f reading arg names: %s", r) } }() var argNames []string for _, arg := range args { argname := "unknown" // This is "&i.Arg" style stuff if unary, ok := arg.(*ast.UnaryExpr); ok { argname = unary.X.(*ast.SelectorExpr).Sel.Name } if ident, ok := arg.(*ast.Ident); ok { argname = ident.Name } if sel, ok := arg.(*ast.SelectorExpr); ok { argname = sel.Sel.Name } if call, ok := arg.(*ast.CallExpr); ok { // Eh, this is pg.Array style stuff. Do a best effort. argname = fmt.Sprintf("call(%d)", len(call.Args)) if fnCall, ok := call.Fun.(*ast.SelectorExpr); ok { argname = fmt.Sprintf("%s(%d)", fnCall.Sel.Name, len(call.Args)) } } if argname == "unknown" { t.Errorf("Unknown arg, cannot parse: %T", arg) } argNames = append(argNames, argname) } return argNames } func pullQueryArgs(fn *ast.FuncDecl) []ast.Expr { for _, exp := range fn.Body.List { // find "rows, err :=" if assign, ok := exp.(*ast.AssignStmt); ok { if len(assign.Lhs) == 2 { if id, ok := assign.Lhs[0].(*ast.Ident); ok && id.Name == "rows" { // This is rows, err := query := assign.Rhs[0].(*ast.CallExpr) if qSel, ok := query.Fun.(*ast.SelectorExpr); ok && qSel.Sel.Name == "QueryContext" { return query.Args } } } } } return nil } func pullRowScanArgs(fn *ast.FuncDecl) []ast.Expr { for _, exp := range fn.Body.List { if forStmt, ok := exp.(*ast.ForStmt); ok { // This came from the debugger window and tracking it down. rowScan := (forStmt.Body. // Second statement in the for loop is the if statement // with rows.can List[1].(*ast.IfStmt). // This is the err := rows.Scan() Init.(*ast.AssignStmt). // Rhs is the row.Scan part Rhs)[0].(*ast.CallExpr) return rowScan.Args } } return nil }