From b2892c3d177b6f28d05af2ecbbff1fa16bfcde11 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 6 Apr 2023 16:16:53 -0500 Subject: [PATCH] test: Increase test coverage on auditable resources (#7038) * test: Increase test coverage on auditable resources When adding a new audit resource, we also need to add it to the function switch statements. This is a likely mistake, now a unit test will check this for you --- coderd/audit/request.go | 10 +- coderd/database/dump.sql | 3 +- .../000115_workspace_proxy_resource.down.sql | 2 + .../000115_workspace_proxy_resource.up.sql | 1 + coderd/database/models.go | 5 +- enterprise/audit/table_internal_test.go | 98 ++++++++++++++++++- 6 files changed, 114 insertions(+), 5 deletions(-) create mode 100644 coderd/database/migrations/000115_workspace_proxy_resource.down.sql create mode 100644 coderd/database/migrations/000115_workspace_proxy_resource.up.sql diff --git a/coderd/audit/request.go b/coderd/audit/request.go index 98359803b4..4700d1a4de 100644 --- a/coderd/audit/request.go +++ b/coderd/audit/request.go @@ -78,6 +78,8 @@ func ResourceTarget[T Auditable](tgt T) string { return "" case database.License: return strconv.Itoa(int(typed.ID)) + case database.WorkspaceProxy: + return typed.Name default: panic(fmt.Sprintf("unknown resource %T", tgt)) } @@ -103,13 +105,15 @@ func ResourceID[T Auditable](tgt T) uuid.UUID { return typed.UserID case database.License: return typed.UUID + case database.WorkspaceProxy: + return typed.ID default: panic(fmt.Sprintf("unknown resource %T", tgt)) } } func ResourceType[T Auditable](tgt T) database.ResourceType { - switch any(tgt).(type) { + switch typed := any(tgt).(type) { case database.Template: return database.ResourceTypeTemplate case database.TemplateVersion: @@ -128,8 +132,10 @@ func ResourceType[T Auditable](tgt T) database.ResourceType { return database.ResourceTypeApiKey case database.License: return database.ResourceTypeLicense + case database.WorkspaceProxy: + return database.ResourceTypeWorkspaceProxy default: - panic(fmt.Sprintf("unknown resource %T", tgt)) + panic(fmt.Sprintf("unknown resource %T", typed)) } } diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 71397f4cfb..964cb7fe17 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -94,7 +94,8 @@ CREATE TYPE resource_type AS ENUM ( 'api_key', 'group', 'workspace_build', - 'license' + 'license', + 'workspace_proxy' ); CREATE TYPE user_status AS ENUM ( diff --git a/coderd/database/migrations/000115_workspace_proxy_resource.down.sql b/coderd/database/migrations/000115_workspace_proxy_resource.down.sql new file mode 100644 index 0000000000..d1d1637f4f --- /dev/null +++ b/coderd/database/migrations/000115_workspace_proxy_resource.down.sql @@ -0,0 +1,2 @@ +-- It's not possible to drop enum values from enum types, so the UP has "IF NOT +-- EXISTS". diff --git a/coderd/database/migrations/000115_workspace_proxy_resource.up.sql b/coderd/database/migrations/000115_workspace_proxy_resource.up.sql new file mode 100644 index 0000000000..adc22ad0ba --- /dev/null +++ b/coderd/database/migrations/000115_workspace_proxy_resource.up.sql @@ -0,0 +1 @@ +ALTER TYPE resource_type ADD VALUE IF NOT EXISTS 'workspace_proxy'; diff --git a/coderd/database/models.go b/coderd/database/models.go index a0b11b2d3b..c4c830ef14 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -884,6 +884,7 @@ const ( ResourceTypeGroup ResourceType = "group" ResourceTypeWorkspaceBuild ResourceType = "workspace_build" ResourceTypeLicense ResourceType = "license" + ResourceTypeWorkspaceProxy ResourceType = "workspace_proxy" ) func (e *ResourceType) Scan(src interface{}) error { @@ -932,7 +933,8 @@ func (e ResourceType) Valid() bool { ResourceTypeApiKey, ResourceTypeGroup, ResourceTypeWorkspaceBuild, - ResourceTypeLicense: + ResourceTypeLicense, + ResourceTypeWorkspaceProxy: return true } return false @@ -950,6 +952,7 @@ func AllResourceTypeValues() []ResourceType { ResourceTypeGroup, ResourceTypeWorkspaceBuild, ResourceTypeLicense, + ResourceTypeWorkspaceProxy, } } diff --git a/enterprise/audit/table_internal_test.go b/enterprise/audit/table_internal_test.go index 7ba9c48598..77e75257d3 100644 --- a/enterprise/audit/table_internal_test.go +++ b/enterprise/audit/table_internal_test.go @@ -1,21 +1,29 @@ package audit import ( + "fmt" "go/types" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/tools/go/packages" + + "github.com/coder/coder/coderd/audit" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/util/slice" ) // TestAuditableResources ensures that all auditable resources are included in // the Auditable interface and vice versa. +// +//nolint:tparallel func TestAuditableResources(t *testing.T) { t.Parallel() pkgs, err := packages.Load(&packages.Config{ - Mode: packages.NeedTypes, + Mode: packages.NeedTypes | packages.NeedDeps, }, "../../coderd/audit") require.NoError(t, err) @@ -37,6 +45,7 @@ func TestAuditableResources(t *testing.T) { require.True(t, ok, "expected Auditable to be a union") found := make(map[string]bool) + expectedList := make([]string, 0) // Now we check we have all the resources in the AuditableResources for i := 0; i < unionType.Len(); i++ { // All types come across like 'github.com/coder/coder/coderd/database.' @@ -44,6 +53,7 @@ func TestAuditableResources(t *testing.T) { _, ok := AuditableResources[typeName] assert.True(t, ok, "missing resource %q from AuditableResources", typeName) found[typeName] = true + expectedList = append(expectedList, typeName) } // Also check that all resources in the table are in the union. We could @@ -52,4 +62,90 @@ func TestAuditableResources(t *testing.T) { _, ok := found[name] assert.True(t, ok, "extra resource %q found in AuditableResources", name) } + + // Various functions that have switch statements to include all Auditable + // resources. Make sure we have all types supported. + // nolint:paralleltest + t.Run("ResourceID", func(t *testing.T) { + // The function being tested, provided here to make it easier to find + _ = audit.ResourceID[database.APIKey] + testAuditFunctionWithSwitch(t, auditPkg, "ResourceID", expectedList) + }) + + // nolint:paralleltest + t.Run("ResourceType", func(t *testing.T) { + // The function being tested, provided here to make it easier to find + _ = audit.ResourceType[database.APIKey] + testAuditFunctionWithSwitch(t, auditPkg, "ResourceType", expectedList) + }) + + // nolint:paralleltest + t.Run("ResourceTarget", func(t *testing.T) { + // The function being tested, provided here to make it easier to find + _ = audit.ResourceTarget[database.APIKey] + testAuditFunctionWithSwitch(t, auditPkg, "ResourceTarget", expectedList) + }) +} + +// testAuditFunctionWithSwitch is a helper function to test that a function has +// a typed switch statement that includes all the types in expectedTypes. +func testAuditFunctionWithSwitch(t *testing.T, pkg *packages.Package, funcName string, expectedTypes []string) { + t.Helper() + + f, ok := pkg.Types.Scope().Lookup(funcName).(*types.Func) + require.True(t, ok, fmt.Sprintf("expected %s to be a function", funcName)) + switchCases := findSwitchTypes(f) + for _, expected := range expectedTypes { + if !slice.Contains(switchCases, expected) { + t.Errorf("%s switch statement is missing type %q. Include it in the switch case block", funcName, expected) + } + } + for _, sc := range switchCases { + if !slice.Contains(expectedTypes, sc) { + t.Errorf("%s switch statement has unexpected type %q. Remove it from the switch case block", funcName, sc) + } + } +} + +// findSwitchTypes is a helper function to find all types a switch statement in +// the function body of f has. +func findSwitchTypes(f *types.Func) []string { + caseTypes := make([]string, 0) + switches := returnSwitchBlocks(f.Scope()) + for _, sc := range switches { + scTypes := findCaseTypes(sc) + caseTypes = append(caseTypes, scTypes...) + } + return caseTypes +} + +func returnSwitchBlocks(sc *types.Scope) []*types.Scope { + switches := make([]*types.Scope, 0) + for i := 0; i < sc.NumChildren(); i++ { + child := sc.Child(i) + cStr := child.String() + // This is the easiest way to tell if it is a switch statement. + if strings.Contains(cStr, "type switch scope") { + switches = append(switches, child) + } + } + return switches +} + +// findCaseTypes returns all case types in a typed switch statement. Excluding +// the "Default:" case. +func findCaseTypes(sc *types.Scope) []string { + caseTypes := make([]string, 0) + for i := 0; i < sc.NumChildren(); i++ { + child := sc.Child(i) + for _, name := range child.Names() { + obj := child.Lookup(name).Type() + typeName := obj.String() + // Ignore the "Default:" case + if typeName != "any" { + caseTypes = append(caseTypes, typeName) + } + } + } + return caseTypes }