mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
chore(coderd): add MockAuditor.Contains test helper (#10421)
* Adds a Contains() method on MockAuditor to help with asserting the presence of an audit log with specific fields. * Updates existing usages of verifyAuditWorkspaceCreated to use the new helper * Updates test referenced in PR#10396.
This commit is contained in:
@ -3,6 +3,10 @@ package audit
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
@ -68,3 +72,76 @@ func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error {
|
||||
func (*MockAuditor) diff(any, any) Map {
|
||||
return Map{}
|
||||
}
|
||||
|
||||
// Contains returns true if, for each non-zero-valued field in expected,
|
||||
// there exists a corresponding audit log in the mock auditor that matches
|
||||
// the expected values. Returns false otherwise.
|
||||
func (a *MockAuditor) Contains(t testing.TB, expected database.AuditLog) bool {
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
for idx, al := range a.auditLogs {
|
||||
if expected.ID != uuid.Nil && al.ID != expected.ID {
|
||||
t.Logf("audit log %d: expected ID %s, got %s", idx+1, expected.ID, al.ID)
|
||||
continue
|
||||
}
|
||||
if !expected.Time.IsZero() && expected.Time != al.Time {
|
||||
t.Logf("audit log %d: expected Time %s, got %s", idx+1, expected.Time, al.Time)
|
||||
continue
|
||||
}
|
||||
if expected.UserID != uuid.Nil && al.UserID != expected.UserID {
|
||||
t.Logf("audit log %d: expected UserID %s, got %s", idx+1, expected.UserID, al.UserID)
|
||||
continue
|
||||
}
|
||||
if expected.OrganizationID != uuid.Nil && al.UserID != expected.UserID {
|
||||
t.Logf("audit log %d: expected OrganizationID %s, got %s", idx+1, expected.OrganizationID, al.OrganizationID)
|
||||
continue
|
||||
}
|
||||
if expected.Ip.Valid && al.Ip.IPNet.String() != expected.Ip.IPNet.String() {
|
||||
t.Logf("audit log %d: expected Ip %s, got %s", idx+1, expected.Ip.IPNet, al.Ip.IPNet)
|
||||
continue
|
||||
}
|
||||
if expected.UserAgent.Valid && al.UserAgent.String != expected.UserAgent.String {
|
||||
t.Logf("audit log %d: expected UserAgent %s, got %s", idx+1, expected.UserAgent.String, al.UserAgent.String)
|
||||
continue
|
||||
}
|
||||
if expected.ResourceType != "" && expected.ResourceType != al.ResourceType {
|
||||
t.Logf("audit log %d: expected ResourceType %s, got %s", idx+1, expected.ResourceType, al.ResourceType)
|
||||
continue
|
||||
}
|
||||
if expected.ResourceID != uuid.Nil && expected.ResourceID != al.ResourceID {
|
||||
t.Logf("audit log %d: expected ResourceID %s, got %s", idx+1, expected.ResourceID, al.ResourceID)
|
||||
continue
|
||||
}
|
||||
if expected.ResourceTarget != "" && expected.ResourceTarget != al.ResourceTarget {
|
||||
t.Logf("audit log %d: expected ResourceTarget %s, got %s", idx+1, expected.ResourceTarget, al.ResourceTarget)
|
||||
continue
|
||||
}
|
||||
if expected.Action != "" && expected.Action != al.Action {
|
||||
t.Logf("audit log %d: expected Action %s, got %s", idx+1, expected.Action, al.Action)
|
||||
continue
|
||||
}
|
||||
if len(expected.Diff) > 0 && slices.Compare(expected.Diff, al.Diff) != 0 {
|
||||
t.Logf("audit log %d: expected Diff %s, got %s", idx+1, string(expected.Diff), string(al.Diff))
|
||||
continue
|
||||
}
|
||||
if expected.StatusCode != 0 && expected.StatusCode != al.StatusCode {
|
||||
t.Logf("audit log %d: expected StatusCode %d, got %d", idx+1, expected.StatusCode, al.StatusCode)
|
||||
continue
|
||||
}
|
||||
if len(expected.AdditionalFields) > 0 && slices.Compare(expected.AdditionalFields, al.AdditionalFields) != 0 {
|
||||
t.Logf("audit log %d: expected AdditionalFields %s, got %s", idx+1, string(expected.AdditionalFields), string(al.AdditionalFields))
|
||||
continue
|
||||
}
|
||||
if expected.RequestID != uuid.Nil && expected.RequestID != al.RequestID {
|
||||
t.Logf("audit log %d: expected RequestID %s, got %s", idx+1, expected.RequestID, al.RequestID)
|
||||
continue
|
||||
}
|
||||
if expected.ResourceIcon != "" && expected.ResourceIcon != al.ResourceIcon {
|
||||
t.Logf("audit log %d: expected ResourceIcon %s, got %s", idx+1, expected.ResourceIcon, al.ResourceIcon)
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -511,7 +511,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
verifyAuditWorkspaceCreated(t, auditor, workspace.Name)
|
||||
assert.True(t, auditor.Contains(t, database.AuditLog{
|
||||
ResourceType: database.ResourceTypeWorkspace,
|
||||
Action: database.AuditActionCreate,
|
||||
ResourceTarget: workspace.Name,
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("CreateFromVersionWithAuditLogs", func(t *testing.T) {
|
||||
@ -535,7 +539,11 @@ func TestPostWorkspacesByOrganization(t *testing.T) {
|
||||
|
||||
require.Equal(t, testWorkspaceBuild.TemplateVersionID, versionTest.ID)
|
||||
require.Equal(t, defaultWorkspaceBuild.TemplateVersionID, versionDefault.ID)
|
||||
verifyAuditWorkspaceCreated(t, auditor, defaultWorkspace.Name)
|
||||
assert.True(t, auditor.Contains(t, database.AuditLog{
|
||||
ResourceType: database.ResourceTypeWorkspace,
|
||||
Action: database.AuditActionCreate,
|
||||
ResourceTarget: defaultWorkspace.Name,
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("InvalidCombinationOfTemplateAndTemplateVersion", func(t *testing.T) {
|
||||
@ -2741,7 +2749,11 @@ func TestWorkspaceDormant(t *testing.T) {
|
||||
Dormant: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, auditRecorder.AuditLogs(), 1)
|
||||
require.True(t, auditRecorder.Contains(t, database.AuditLog{
|
||||
Action: database.AuditActionWrite,
|
||||
ResourceType: database.ResourceTypeWorkspace,
|
||||
ResourceTarget: workspace.Name,
|
||||
}))
|
||||
|
||||
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
require.NoError(t, err, "fetch provisioned workspace")
|
||||
@ -2804,25 +2816,3 @@ func TestWorkspaceDormant(t *testing.T) {
|
||||
coderdtest.MustTransitionWorkspace(t, client, workspace.ID, database.WorkspaceTransitionStop, database.WorkspaceTransitionStart)
|
||||
})
|
||||
}
|
||||
|
||||
func verifyAuditWorkspaceCreated(t *testing.T, auditor *audit.MockAuditor, workspaceName string) {
|
||||
var auditLogs []database.AuditLog
|
||||
ok := assert.Eventually(t, func() bool {
|
||||
auditLogs = auditor.AuditLogs()
|
||||
|
||||
for _, auditLog := range auditLogs {
|
||||
if auditLog.Action == database.AuditActionCreate &&
|
||||
auditLog.ResourceType == database.ResourceTypeWorkspace &&
|
||||
auditLog.ResourceTarget == workspaceName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
if !ok {
|
||||
for i, auditLog := range auditLogs {
|
||||
t.Logf("%d. Audit: ID=%s action=%s resourceID=%s resourceType=%s resourceTarget=%s", i+1, auditLog.ID, auditLog.Action, auditLog.ResourceID, auditLog.ResourceType, auditLog.ResourceTarget)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user