diff --git a/coderd/database/gen/fake/main.go b/coderd/database/gen/fake/main.go index b74f5d18d7..2d399192fa 100644 --- a/coderd/database/gen/fake/main.go +++ b/coderd/database/gen/fake/main.go @@ -1,10 +1,12 @@ package main import ( + "fmt" "go/format" "go/token" "log" "os" + "path" "github.com/dave/dst" "github.com/dave/dst/decorator" @@ -65,6 +67,76 @@ func run() error { } for _, fn := range funcs { + var bodyStmts []dst.Stmt + if len(fn.Func.Params.List) == 2 && fn.Func.Params.List[1].Names[0].Name == "arg" { + /* + err := validateDatabaseType(arg) + if err != nil { + return database.User{}, err + } + */ + bodyStmts = append(bodyStmts, &dst.AssignStmt{ + Lhs: []dst.Expr{dst.NewIdent("err")}, + Tok: token.DEFINE, + Rhs: []dst.Expr{ + &dst.CallExpr{ + Fun: &dst.Ident{ + Name: "validateDatabaseType", + }, + Args: []dst.Expr{dst.NewIdent("arg")}, + }, + }, + }) + returnStmt := &dst.ReturnStmt{ + Results: []dst.Expr{}, // Filled below. + } + bodyStmts = append(bodyStmts, &dst.IfStmt{ + Cond: &dst.BinaryExpr{ + X: dst.NewIdent("err"), + Op: token.NEQ, + Y: dst.NewIdent("nil"), + }, + Body: &dst.BlockStmt{ + List: []dst.Stmt{ + returnStmt, + }, + }, + Decs: dst.IfStmtDecorations{ + NodeDecs: dst.NodeDecs{ + After: dst.EmptyLine, + }, + }, + }) + for _, r := range fn.Func.Results.List { + switch typ := r.Type.(type) { + case *dst.StarExpr, *dst.ArrayType: + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("nil")) + case *dst.Ident: + if typ.Path != "" { + returnStmt.Results = append(returnStmt.Results, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name))) + } else { + switch typ.Name { + case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr", + "int8", "int16", "int32", "int64", "int", + "byte", "rune", + "float32", "float64", + "complex64", "complex128": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("0")) + case "string": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("\"\"")) + case "bool": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("false")) + case "error": + returnStmt.Results = append(returnStmt.Results, dst.NewIdent("err")) + default: + panic(fmt.Sprintf("unknown ident: %#v", r.Type)) + } + } + default: + panic(fmt.Sprintf("unknown return type: %T", r.Type)) + } + } + } decl, ok := declByName[fn.Name] if !ok { // Not implemented! @@ -90,21 +162,19 @@ func run() error { }, }, Body: &dst.BlockStmt{ - List: []dst.Stmt{ - &dst.ExprStmt{ - X: &dst.CallExpr{ - Fun: &dst.Ident{ - Name: "panic", - }, - Args: []dst.Expr{ - &dst.BasicLit{ - Kind: token.STRING, - Value: "\"Not implemented\"", - }, + List: append(bodyStmts, &dst.ExprStmt{ + X: &dst.CallExpr{ + Fun: &dst.Ident{ + Name: "panic", + }, + Args: []dst.Expr{ + &dst.BasicLit{ + Kind: token.STRING, + Value: "\"Not implemented\"", }, }, }, - }, + }), }, } } @@ -178,9 +248,25 @@ func readStoreInterface() ([]storeMethod, error) { if t == nil { continue } + var ( + ident *dst.Ident + ok bool + ) for _, f := range t.List { - ident, ok := f.Type.(*dst.Ident) - if !ok { + switch typ := f.Type.(type) { + case *dst.StarExpr: + ident, ok = typ.X.(*dst.Ident) + if !ok { + continue + } + case *dst.ArrayType: + ident, ok = typ.Elt.(*dst.Ident) + if !ok { + continue + } + case *dst.Ident: + ident = typ + default: continue } if !ident.IsExported() {