Files
coder/scripts/dbgen/main.go
Thomas Kosiewski 74e1d5c4b6 feat: implement OAuth2 dynamic client registration (RFC 7591/7592) (#18645)
# Implement OAuth2 Dynamic Client Registration (RFC 7591/7592)

This PR implements OAuth2 Dynamic Client Registration according to RFC 7591 and Client Configuration Management according to RFC 7592. These standards allow OAuth2 clients to register themselves programmatically with Coder, which acts as the authorization server.

Key changes include:

1. Added database schema extensions to support RFC 7591/7592 fields in the `oauth2_provider_apps` table
2. Implemented `/oauth2/register` endpoint for dynamic client registration (RFC 7591)
3. Added client configuration management endpoints (RFC 7592):
   - GET/PUT/DELETE `/oauth2/clients/{client_id}`
   - Registration access token validation middleware

4. Added comprehensive validation for OAuth2 client metadata:
   - URI validation with support for custom schemes for native apps
   - Grant type and response type validation
   - Token endpoint authentication method validation

5. Enhanced developer documentation with:
   - RFC compliance guidelines
   - Testing best practices to avoid race conditions
   - Systematic debugging approaches for OAuth2 implementations

The implementation follows security best practices from the RFCs, including proper token handling, secure defaults, and appropriate error responses. This enables third-party applications to integrate with Coder's OAuth2 provider capabilities programmatically.
2025-07-03 18:33:47 +02:00

680 lines
17 KiB
Go

package main
import (
"bufio"
"bytes"
"fmt"
"go/format"
"go/token"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
"strings"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/decorator/resolver/goast"
"github.com/dave/dst/decorator/resolver/guess"
"golang.org/x/tools/imports"
"golang.org/x/xerrors"
)
// funcs holds every querier method, sqlcQuerier's followed by
// customQuerier's, in declaration order. It is populated by init.
var funcs []querierFunction

// funcByName is the set of method names present in funcs.
var funcByName map[string]struct{}
// init loads the querier methods once and indexes them by name. It panics
// when they cannot be read, because nothing can be generated without them.
func init() {
	loaded, err := readQuerierFunctions()
	if err != nil {
		panic(err)
	}
	funcs = loaded

	funcByName = make(map[string]struct{}, len(loaded))
	for _, fn := range loaded {
		funcByName[fn.Name] = struct{}{}
	}
}
// main runs the generator. On failure it prints the error to stderr and
// exits with status 1.
func main() {
	if err := run(); err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "error: %s\n", err)
		os.Exit(1)
	}
}
// run orders and stubs the querier methods of the dbmem, dbmetrics and
// dbauthz implementations, then regenerates the unique and foreign key
// constraint enums from dump.sql. Steps run in that order and the first
// failure is returned with the name of the failing step.
func run() error {
	here, err := localFilePath()
	if err != nil {
		return err
	}
	databasePath := filepath.Join(here, "..", "..", "..", "coderd", "database")

	notImplemented := func(_ stubParams) string {
		return `panic("not implemented")`
	}
	// The dbmetrics stub forwards the call to m.s and records its latency.
	metricsStub := func(params stubParams) string {
		return fmt.Sprintf(`
start := time.Now()
%s := m.s.%s(%s)
m.queryLatencies.WithLabelValues("%s").Observe(time.Since(start).Seconds())
return %s
`, params.Returns, params.FuncName, params.Parameters, params.FuncName, params.Returns)
	}

	targets := []struct {
		pkg        string
		file       string
		receiver   string
		structName string
		stub       func(stubParams) string
	}{
		{pkg: "dbmem", file: "dbmem.go", receiver: "q", structName: "FakeQuerier", stub: notImplemented},
		{pkg: "dbmetrics", file: "querymetrics.go", receiver: "m", structName: "queryMetricsStore", stub: metricsStub},
		{pkg: "dbauthz", file: "dbauthz.go", receiver: "q", structName: "querier", stub: notImplemented},
	}
	for _, target := range targets {
		filePath := filepath.Join(databasePath, target.pkg, target.file)
		if err := orderAndStubDatabaseFunctions(filePath, target.receiver, target.structName, target.stub); err != nil {
			return xerrors.Errorf("stub %s: %w", target.pkg, err)
		}
	}

	if err := generateUniqueConstraints(); err != nil {
		return xerrors.Errorf("generate unique constraints: %w", err)
	}
	if err := generateForeignKeyConstraints(); err != nil {
		return xerrors.Errorf("generate foreign key constraints: %w", err)
	}
	return nil
}
// generateUniqueConstraints generates the UniqueConstraint enum in
// unique_constraint.go from the UNIQUE / PRIMARY KEY constraints and unique
// indexes declared in dump.sql.
func generateUniqueConstraints() error {
	databasePath, err := constraintsDatabasePath()
	if err != nil {
		return err
	}
	queries, err := readDumpStatements(databasePath, func(query string) bool {
		return strings.Contains(query, "UNIQUE") || strings.Contains(query, "PRIMARY KEY")
	})
	if err != nil {
		return err
	}

	s := &bytes.Buffer{}
	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// UniqueConstraint represents a named unique constraint on a table.
type UniqueConstraint string
// UniqueConstraint enums.
const (
`)
	for _, query := range queries {
		var (
			name string
			ok   bool
		)
		switch {
		case strings.Contains(query, "ALTER TABLE") && strings.Contains(query, "ADD CONSTRAINT"):
			// e.g. "ALTER TABLE ONLY public.t ADD CONSTRAINT t_pkey PRIMARY KEY (id);"
			name, ok = tokenAfter(query, "ADD", "CONSTRAINT")
		case strings.Contains(query, "CREATE UNIQUE INDEX"):
			// e.g. "CREATE UNIQUE INDEX idx_t ON public.t USING btree (c);"
			name, ok = tokenAfter(query, "CREATE", "UNIQUE", "INDEX")
		default:
			return xerrors.Errorf("unknown unique constraint format: %s", query)
		}
		if !ok {
			return xerrors.Errorf("unique constraint name not found: %s", query)
		}
		_, _ = fmt.Fprintf(s, "\tUnique%s UniqueConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	return writeGeneratedSource(filepath.Join(databasePath, "unique_constraint.go"), s.Bytes())
}

// generateForeignKeyConstraints generates the ForeignKeyConstraint enum in
// foreign_key_constraint.go from the FOREIGN KEY constraints declared in
// dump.sql.
func generateForeignKeyConstraints() error {
	databasePath, err := constraintsDatabasePath()
	if err != nil {
		return err
	}
	queries, err := readDumpStatements(databasePath, func(query string) bool {
		return strings.Contains(query, "FOREIGN KEY")
	})
	if err != nil {
		return err
	}

	s := &bytes.Buffer{}
	_, _ = fmt.Fprint(s, `// Code generated by scripts/dbgen/main.go. DO NOT EDIT.
package database
`)
	_, _ = fmt.Fprint(s, `
// ForeignKeyConstraint represents a named foreign key constraint on a table.
type ForeignKeyConstraint string
// ForeignKeyConstraint enums.
const (
`)
	for _, query := range queries {
		if !strings.Contains(query, "ALTER TABLE") || !strings.Contains(query, "ADD CONSTRAINT") {
			return xerrors.Errorf("unknown foreign key constraint format: %s", query)
		}
		name, ok := tokenAfter(query, "ADD", "CONSTRAINT")
		if !ok {
			return xerrors.Errorf("foreign key constraint name not found: %s", query)
		}
		_, _ = fmt.Fprintf(s, "\tForeignKey%s ForeignKeyConstraint = %q // %s\n", nameFromSnakeCase(name), name, query)
	}
	_, _ = fmt.Fprint(s, ")\n")

	return writeGeneratedSource(filepath.Join(databasePath, "foreign_key_constraint.go"), s.Bytes())
}

// constraintsDatabasePath returns the coderd/database directory, resolved
// relative to this generator's source file.
func constraintsDatabasePath() (string, error) {
	localPath, err := localFilePath()
	if err != nil {
		return "", err
	}
	return filepath.Join(localPath, "..", "..", "..", "coderd", "database"), nil
}

// readDumpStatements splits databasePath/dump.sql into statements and returns
// those for which keep reports true. A statement is the concatenation of its
// trimmed lines joined by single spaces, terminated by a line ending in ";".
// Lines starting with "--" and blank lines are skipped.
func readDumpStatements(databasePath string, keep func(query string) bool) ([]string, error) {
	dumpPath := filepath.Join(databasePath, "dump.sql")
	dump, err := os.Open(dumpPath)
	if err != nil {
		return nil, err
	}
	defer dump.Close()

	var (
		statements []string
		query      strings.Builder
	)
	scanner := bufio.NewScanner(dump)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		switch {
		case strings.HasPrefix(line, "--"), line == "":
			// Comments and blank lines never contribute to a statement.
		case strings.HasSuffix(line, ";"):
			query.WriteString(line)
			if stmt := query.String(); keep(stmt) {
				statements = append(statements, stmt)
			}
			query.Reset()
		default:
			query.WriteString(line)
			query.WriteByte(' ')
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, xerrors.Errorf("scan %s: %w", dumpPath, err)
	}
	return statements, nil
}

// tokenAfter returns the whitespace-separated token that immediately follows
// the first occurrence of the consecutive tokens keywords in query. It
// reports false when the sequence is absent or is the last thing in query.
func tokenAfter(query string, keywords ...string) (string, bool) {
	fields := strings.Fields(query)
	for i := 0; i+len(keywords) < len(fields); i++ {
		matched := true
		for j, keyword := range keywords {
			if fields[i+j] != keyword {
				matched = false
				break
			}
		}
		if matched {
			return fields[i+len(keywords)], true
		}
	}
	return "", false
}

// writeGeneratedSource formats src and fixes its imports with imports.Process,
// then writes the result to outputPath.
func writeGeneratedSource(outputPath string, src []byte) error {
	data, err := imports.Process(outputPath, src, &imports.Options{
		Comments: true,
	})
	if err != nil {
		return err
	}
	return os.WriteFile(outputPath, data, 0o600)
}
// stubParams describes a querier method for which a stub body is generated.
type stubParams struct {
	// FuncName is the name of the querier method being stubbed.
	FuncName string
	// Parameters is the comma-separated list of the method's parameter names.
	Parameters string
	// Returns is a comma-separated list of synthesized result variable names
	// (r0, r1, ...), one per result field.
	Returns string
}
// orderAndStubDatabaseFunctions rewrites the Go file at filePath so that the
// methods of structName that implement querier functions appear in querier
// interface order. A stub is generated for every querier function the struct
// does not implement yet. This is useful when a new function is added to the
// database: every implementation gets a correctly ordered method.
//
// filePath is the path to the file that contains all the functions; its
// directory name is used as the package name.
// receiver is the receiver name used for newly generated stubs.
// structName is the name of the struct that contains the functions.
// stub returns the body source used for a newly generated stub.
//
// Existing implementations keep their bodies. Only their parameter and result
// types are updated to match the interface. In the dbmem package, new stubs for
// functions shaped like (ctx, arg) start with a validateDatabaseType(arg) check.
func orderAndStubDatabaseFunctions(filePath, receiver, structName string, stub func(params stubParams) string) error {
	packageName := filepath.Base(filepath.Dir(filePath))
	contents, err := os.ReadFile(filePath)
	if err != nil {
		return xerrors.Errorf("read %s: %w", filePath, err)
	}
	// Required to preserve imports!
	f, err := decorator.NewDecoratorWithImports(token.NewFileSet(), packageName, goast.New()).Parse(contents)
	if err != nil {
		return xerrors.Errorf("parse %s: %w", filePath, err)
	}

	declByName, pointer := extractQuerierMethods(f, structName)
	typeName := structName
	if pointer {
		typeName = "*" + typeName
	}
	// Input validation is only relevant for dbmem. Compare the package
	// directory, not the whole path, so an unrelated "dbmem" elsewhere in the
	// checkout path doesn't enable it.
	validateArgs := packageName == "dbmem"

	for _, fn := range funcs {
		decl, ok := declByName[fn.Name]
		if ok {
			syncFuncSignature(decl, fn)
		} else {
			var bodyStmts []dst.Stmt
			if validateArgs {
				bodyStmts = argValidationStmts(fn)
			}
			decl, err = newStubFuncDecl(fn, receiver, typeName, bodyStmts, stub)
			if err != nil {
				return err
			}
		}
		f.Decls = append(f.Decls, decl)
	}

	// Required to preserve imports!
	restorer := decorator.NewRestorerWithImports(packageName, guess.New())
	restored, err := restorer.RestoreFile(f)
	if err != nil {
		return xerrors.Errorf("restore package: %w", err)
	}
	var buf bytes.Buffer
	if err := format.Node(&buf, restorer.Fset, restored); err != nil {
		return xerrors.Errorf("format package: %w", err)
	}
	data, err := imports.Process(filePath, buf.Bytes(), &imports.Options{
		Comments: true,
	})
	if err != nil {
		return xerrors.Errorf("process imports: %w", err)
	}
	return os.WriteFile(filePath, data, 0o600)
}

// extractQuerierMethods removes from f.Decls the methods of structName whose
// names are querier functions, returning them keyed by name. It also reports
// whether any method of structName has a pointer receiver.
func extractQuerierMethods(f *dst.File, structName string) (map[string]*dst.FuncDecl, bool) {
	declByName := map[string]*dst.FuncDecl{}
	pointer := false
	kept := f.Decls[:0]
	for _, decl := range f.Decls {
		funcDecl, ok := decl.(*dst.FuncDecl)
		if !ok || funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 {
			kept = append(kept, decl)
			continue
		}
		var ident *dst.Ident
		isPointer := false
		switch t := funcDecl.Recv.List[0].Type.(type) {
		case *dst.Ident:
			ident = t
		case *dst.StarExpr:
			ident, _ = t.X.(*dst.Ident)
			isPointer = true
		}
		if ident == nil || ident.Name != structName {
			kept = append(kept, decl)
			continue
		}
		// Only structName's own receivers decide how new stubs are declared;
		// pointer receivers on other types in the file must not leak in.
		if isPointer {
			pointer = true
		}
		if _, ok := funcByName[funcDecl.Name.Name]; !ok {
			kept = append(kept, decl)
			continue
		}
		declByName[funcDecl.Name.Name] = funcDecl
	}
	f.Decls = kept
	return declByName, pointer
}

// argValidationStmts returns the statements
//
//	err := validateDatabaseType(arg)
//	if err != nil {
//		return <values from zeroReturnExprs>
//	}
//
// for querier functions with exactly two parameters, the second named "arg".
// For any other function it returns nil.
func argValidationStmts(fn querierFunction) []dst.Stmt {
	params := fn.Func.Params
	if params == nil || len(params.List) != 2 {
		return nil
	}
	if names := params.List[1].Names; len(names) == 0 || names[0].Name != "arg" {
		return nil
	}
	return []dst.Stmt{
		&dst.AssignStmt{
			Lhs: []dst.Expr{dst.NewIdent("err")},
			Tok: token.DEFINE,
			Rhs: []dst.Expr{
				&dst.CallExpr{
					Fun: &dst.Ident{
						Name: "validateDatabaseType",
					},
					Args: []dst.Expr{dst.NewIdent("arg")},
				},
			},
		},
		&dst.IfStmt{
			Cond: &dst.BinaryExpr{
				X:  dst.NewIdent("err"),
				Op: token.NEQ,
				Y:  dst.NewIdent("nil"),
			},
			Body: &dst.BlockStmt{
				List: []dst.Stmt{
					&dst.ReturnStmt{Results: zeroReturnExprs(fn.Func.Results)},
				},
			},
			Decs: dst.IfStmtDecorations{
				NodeDecs: dst.NodeDecs{
					After: dst.EmptyLine,
				},
			},
		},
	}
}

// zeroReturnExprs returns one placeholder expression per result field, used
// when returning early with err:
//   - nil for pointer, slice/array and package-selector types;
//   - pkg.T{} for identifiers tagged with an import path;
//   - the zero literal for builtin scalars;
//   - err for error.
//
// It panics on any other result type.
func zeroReturnExprs(results *dst.FieldList) []dst.Expr {
	exprs := []dst.Expr{}
	if results == nil {
		return exprs
	}
	for _, r := range results.List {
		switch typ := r.Type.(type) {
		case *dst.StarExpr, *dst.ArrayType, *dst.SelectorExpr:
			exprs = append(exprs, dst.NewIdent("nil"))
		case *dst.Ident:
			if typ.Path != "" {
				exprs = append(exprs, dst.NewIdent(fmt.Sprintf("%s.%s{}", path.Base(typ.Path), typ.Name)))
				continue
			}
			switch typ.Name {
			case "uint8", "uint16", "uint32", "uint64", "uint", "uintptr",
				"int8", "int16", "int32", "int64", "int",
				"byte", "rune",
				"float32", "float64",
				"complex64", "complex128":
				exprs = append(exprs, dst.NewIdent("0"))
			case "string":
				exprs = append(exprs, dst.NewIdent("\"\""))
			case "bool":
				exprs = append(exprs, dst.NewIdent("false"))
			case "error":
				exprs = append(exprs, dst.NewIdent("err"))
			default:
				panic(fmt.Sprintf("unknown ident: %#v", r.Type))
			}
		default:
			panic(fmt.Sprintf("unknown return type: %T", r.Type))
		}
	}
	return exprs
}

// newStubFuncDecl builds a method fn.Name on a receiver named receiver of type
// typeName. The method has fn's signature, and its body is bodyStmts followed
// by the statements produced by stub.
func newStubFuncDecl(fn querierFunction, receiver, typeName string, bodyStmts []dst.Stmt, stub func(params stubParams) string) (*dst.FuncDecl, error) {
	params := make([]string, 0)
	if fn.Func.Params != nil {
		for _, p := range fn.Func.Params.List {
			for _, name := range p.Names {
				params = append(params, name.Name)
			}
		}
	}
	returns := make([]string, 0)
	if fn.Func.Results != nil {
		for i := range fn.Func.Results.List {
			returns = append(returns, fmt.Sprintf("r%d", i))
		}
	}
	stubDecl, err := compileFuncDecl(stub(stubParams{
		FuncName:   fn.Name,
		Parameters: strings.Join(params, ","),
		Returns:    strings.Join(returns, ","),
	}))
	if err != nil {
		return nil, xerrors.Errorf("compile func decl: %w", err)
	}
	// Not implemented!
	return &dst.FuncDecl{
		Name: dst.NewIdent(fn.Name),
		Type: &dst.FuncType{
			Func:       true,
			TypeParams: fn.Func.TypeParams,
			Params:     fn.Func.Params,
			Results:    fn.Func.Results,
			Decs:       fn.Func.Decs,
		},
		Recv: &dst.FieldList{
			List: []*dst.Field{{
				Names: []*dst.Ident{dst.NewIdent(receiver)},
				Type:  dst.NewIdent(typeName),
			}},
		},
		Decs: dst.FuncDeclDecorations{
			NodeDecs: dst.NodeDecs{
				Before: dst.EmptyLine,
				After:  dst.EmptyLine,
			},
		},
		Body: &dst.BlockStmt{
			List: append(bodyStmts, stubDecl.Body.List...),
		},
	}, nil
}

// syncFuncSignature updates an existing implementation's parameter and result
// types to match the querier interface.
func syncFuncSignature(decl *dst.FuncDecl, fn querierFunction) {
	decl.Type.Params = syncFieldTypes(decl.Type.Params, fn.Func.Params)
	decl.Type.Results = syncFieldTypes(decl.Type.Results, fn.Func.Results)
}

// syncFieldTypes makes the types in have match those in want, position by
// position. It appends any fields have lacks and keeps field names and any
// extra trailing fields. A nil have is allocated when want has fields.
func syncFieldTypes(have, want *dst.FieldList) *dst.FieldList {
	if want == nil {
		return have
	}
	if have == nil {
		have = &dst.FieldList{Opening: true, Closing: true}
	}
	for i, field := range want.List {
		if len(have.List) < i+1 {
			have.List = append(have.List, field)
		}
		// Only replace differing types so unchanged nodes keep their decorations.
		if !reflect.DeepEqual(have.List[i].Type, field.Type) {
			have.List[i].Type = field.Type
		}
	}
	return have
}
// compileFuncDecl parses code as the body of a throwaway function, stub() in
// package stub, and returns that function's declaration. Callers can then
// splice the parsed statements into another function.
func compileFuncDecl(code string) (*dst.FuncDecl, error) {
	src := fmt.Sprintf(`package stub
func stub() {
%s
}`, strings.TrimSpace(code))
	file, err := decorator.Parse(src)
	if err != nil {
		return nil, err
	}
	if n := len(file.Decls); n != 1 {
		return nil, xerrors.Errorf("expected 1 decl, got %d", n)
	}
	if decl, ok := file.Decls[0].(*dst.FuncDecl); ok {
		return decl, nil
	}
	return nil, xerrors.Errorf("expected func decl, got %T", file.Decls[0])
}
// querierFunction is a single method of a querier interface.
type querierFunction struct {
	// Name is the method name, e.g. "GetUserByID".
	Name string
	// Func is the method's signature as a dst node. loadInterfaceFuncs tags
	// top-level exported type identifiers in it with the coderd/database
	// import path.
	Func *dst.FuncType
}
// readQuerierFunctions reads the functions from coderd/database/querier.go
func readQuerierFunctions() ([]querierFunction, error) {
f, err := parseDBFile("querier.go")
if err != nil {
return nil, xerrors.Errorf("parse querier.go: %w", err)
}
funcs, err := loadInterfaceFuncs(f, "sqlcQuerier")
if err != nil {
return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err)
}
customFile, err := parseDBFile("modelqueries.go")
if err != nil {
return nil, xerrors.Errorf("parse modelqueriers.go: %w", err)
}
// Custom funcs should be appended after the regular functions
customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier")
if err != nil {
return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err)
}
return append(funcs, customFuncs...), nil
}
// parseDBFile reads the named file from the coderd/database directory, which
// is resolved relative to this generator's source file, and parses it into a
// dst.File.
func parseDBFile(filename string) (*dst.File, error) {
	here, err := localFilePath()
	if err != nil {
		return nil, err
	}
	src, err := os.ReadFile(filepath.Join(here, "..", "..", "..", "coderd", "database", filename))
	if err != nil {
		return nil, xerrors.Errorf("read %s: %w", filename, err)
	}
	return decorator.Parse(src)
}
// loadInterfaceFuncs returns the methods of the interface named
// interfaceName declared in f, including methods of embedded interfaces that
// interfaceMethods can resolve. Exported type identifiers used directly as a
// parameter, result or type-parameter type (bare, behind a pointer, or as a
// slice element) are tagged with the coderd/database import path. This lets
// the signatures be restored into other packages as database.T.
func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) {
	var querier *dst.InterfaceType
findInterface:
	for _, decl := range f.Decls {
		genDecl, ok := decl.(*dst.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*dst.TypeSpec)
			if !ok || typeSpec.Name.Name != interfaceName {
				continue
			}
			querier, ok = typeSpec.Type.(*dst.InterfaceType)
			if !ok {
				return nil, xerrors.Errorf("unexpected %s type: %T", interfaceName, typeSpec.Type)
			}
			break findInterface
		}
	}
	if querier == nil {
		return nil, xerrors.Errorf("interface %s not found", interfaceName)
	}

	result := []querierFunction{}
	for _, method := range interfaceMethods(querier) {
		funcType, ok := method.Type.(*dst.FuncType)
		if !ok {
			continue
		}
		for _, list := range []*dst.FieldList{funcType.Params, funcType.Results, funcType.TypeParams} {
			if list == nil {
				continue
			}
			for _, field := range list.List {
				var ident *dst.Ident
				switch typ := field.Type.(type) {
				case *dst.Ident:
					ident = typ
				case *dst.StarExpr:
					ident, _ = typ.X.(*dst.Ident)
				case *dst.SelectorExpr:
					ident, _ = typ.X.(*dst.Ident)
				case *dst.ArrayType:
					ident, _ = typ.Elt.(*dst.Ident)
				}
				// If the type is exported then we should be able to find it
				// in the database package!
				if ident == nil || !ident.IsExported() {
					continue
				}
				ident.Path = "github.com/coder/coder/v2/coderd/database"
			}
		}
		result = append(result, querierFunction{
			Name: method.Names[0].Name,
			Func: funcType,
		})
	}
	return result, nil
}
// localFilePath returns the path of this source file (`main.go` in the dbgen
// package) as recorded at build time. Callers resolve repository paths
// relative to it.
func localFilePath() (string, error) {
	if _, filename, _, ok := runtime.Caller(0); ok {
		return filename, nil
	}
	return "", xerrors.Errorf("failed to get caller")
}
// nameFromSnakeCase converts a snake_case identifier to CamelCase. Some
// segments get special handling:
//   - id, ids, jwt, api, uuid and gitsshkeys use Go initialism casing
//     (ID, IDs, JWT, API, UUID, GitSSHKeys);
//   - idx becomes Index;
//   - fkey segments are dropped.
//
// Every other segment is passed through strings.Title.
func nameFromSnakeCase(s string) string {
	overrides := map[string]string{
		"id":         "ID",
		"ids":        "IDs",
		"jwt":        "JWT",
		"idx":        "Index",
		"api":        "API",
		"uuid":       "UUID",
		"gitsshkeys": "GitSSHKeys",
		"fkey":       "", // ignore
	}
	var b strings.Builder
	for _, segment := range strings.Split(s, "_") {
		if replacement, ok := overrides[segment]; ok {
			b.WriteString(replacement)
			continue
		}
		b.WriteString(strings.Title(segment))
	}
	return b.String()
}
// interfaceMethods flattens an interface into its method fields. It recurses
// into inline interface types and into embedded interfaces whose declaration
// is reachable through the identifier's Obj. Embedded interfaces declared
// outside the parsed file have a nil Obj and are skipped.
func interfaceMethods(i *dst.InterfaceType) []*dst.Field {
	var methods []*dst.Field
	for _, field := range i.Methods.List {
		var nested *dst.InterfaceType
		switch typ := field.Type.(type) {
		case *dst.FuncType:
			methods = append(methods, field)
			continue
		case *dst.InterfaceType:
			nested = typ
		case *dst.Ident:
			// Embedded interfaces are Idents -> TypeSpec -> InterfaceType.
			if typ.Obj == nil {
				continue
			}
			spec, ok := typ.Obj.Decl.(*dst.TypeSpec)
			if !ok {
				continue
			}
			nested, _ = spec.Type.(*dst.InterfaceType)
		}
		if nested != nil {
			methods = append(methods, interfaceMethods(nested)...)
		}
	}
	return methods
}