chore: convert dbauthz tests to also run with Postgres (#15862)

Another PR to address https://github.com/coder/coder/issues/15109.

- adds the DisableForeignKeysAndTriggers utility, which simplifies
converting tests from in-mem to postgres
- converts the dbauthz test suite to pass on both the in-mem db and
Postgres
This commit is contained in:
Hugo Dutka
2025-01-08 16:22:51 +01:00
committed by GitHub
parent 13cfaae619
commit 106b1cd3bc
13 changed files with 1680 additions and 337 deletions

View File

@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
@ -1366,6 +1367,13 @@ func (q *querier) DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context,
return q.db.DeleteWorkspaceAgentPortSharesByTemplate(ctx, templateID)
}
func (q *querier) DisableForeignKeysAndTriggers(ctx context.Context) error {
if !testing.Testing() {
return xerrors.Errorf("DisableForeignKeysAndTriggers is only allowed in tests")
}
return q.db.DisableForeignKeysAndTriggers(ctx)
}
func (q *querier) EnqueueNotificationMessage(ctx context.Context, arg database.EnqueueNotificationMessageParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceNotificationMessage); err != nil {
return err

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,6 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/rbac"
)
@ -22,13 +21,9 @@ import (
func TestGroupsAuth(t *testing.T) {
t.Parallel()
if dbtestutil.WillUsePostgres() {
t.Skip("this test would take too long to run on postgres")
}
authz := rbac.NewAuthorizer(prometheus.NewRegistry())
db := dbauthz.New(dbmem.New(), authz, slogtest.Make(t, &slogtest.Options{
store, _ := dbtestutil.NewDB(t)
db := dbauthz.New(store, authz, slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}), coderdtest.AccessControlStorePointer())

View File

@ -22,8 +22,8 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/util/slice"
@ -114,7 +114,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
methodName := names[len(names)-1]
s.methodAccounting[methodName]++
db := dbmem.New()
db, _ := dbtestutil.NewDB(t)
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
@ -217,7 +217,11 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
}
}
rec.AssertActor(s.T(), actor, pairs...)
if testCase.outOfOrder {
rec.AssertOutOfOrder(s.T(), actor, pairs...)
} else {
rec.AssertActor(s.T(), actor, pairs...)
}
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
})
}
@ -236,6 +240,8 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn(rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil))
// Override the SQL filter to always fail.
az.OverrideSQLFilter("FALSE")
// If we have assertions, that means the method should FAIL
// if RBAC will disallow the request. The returned error should
@ -328,6 +334,14 @@ type expects struct {
notAuthorizedExpect string
cancelledCtxExpect string
successAuthorizer func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error
outOfOrder bool
}
// OutOfOrder is optional. It controls whether the assertions should be
// asserted in order.
func (m *expects) OutOfOrder() *expects {
m.outOfOrder = true
return m
}
// Asserts is required. Asserts the RBAC authorize calls that should be made.
@ -358,6 +372,24 @@ func (m *expects) Errors(err error) *expects {
return m
}
// ErrorsWithPG is optional. If it is never called, it will not be asserted.
// It will only be asserted if the test is running with a Postgres database.
func (m *expects) ErrorsWithPG(err error) *expects {
if dbtestutil.WillUsePostgres() {
return m.Errors(err)
}
return m
}
// ErrorsWithInMemDB is optional. If it is never called, it will not be asserted.
// It will only be asserted if the test is running with an in-memory database.
func (m *expects) ErrorsWithInMemDB(err error) *expects {
if !dbtestutil.WillUsePostgres() {
return m.Errors(err)
}
return m
}
func (m *expects) FailSystemObjectChecks() *expects {
return m.WithSuccessAuthorizer(func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error {
if obj.Type == rbac.ResourceSystem.Type {

View File

@ -2206,6 +2206,11 @@ func (q *FakeQuerier) DeleteWorkspaceAgentPortSharesByTemplate(_ context.Context
return nil
}
func (*FakeQuerier) DisableForeignKeysAndTriggers(_ context.Context) error {
// This is a no-op in the in-memory database.
return nil
}
func (q *FakeQuerier) EnqueueNotificationMessage(_ context.Context, arg database.EnqueueNotificationMessageParams) error {
err := validateDatabaseType(arg)
if err != nil {

View File

@ -413,6 +413,13 @@ func (m queryMetricsStore) DeleteWorkspaceAgentPortSharesByTemplate(ctx context.
return r0
}
func (m queryMetricsStore) DisableForeignKeysAndTriggers(ctx context.Context) error {
start := time.Now()
r0 := m.s.DisableForeignKeysAndTriggers(ctx)
m.queryLatencies.WithLabelValues("DisableForeignKeysAndTriggers").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) EnqueueNotificationMessage(ctx context.Context, arg database.EnqueueNotificationMessageParams) error {
start := time.Now()
r0 := m.s.EnqueueNotificationMessage(ctx, arg)

View File

@ -728,6 +728,20 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceAgentPortSharesByTemplate(arg0,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceAgentPortSharesByTemplate", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceAgentPortSharesByTemplate), arg0, arg1)
}
// DisableForeignKeysAndTriggers mocks base method.
func (m *MockStore) DisableForeignKeysAndTriggers(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisableForeignKeysAndTriggers", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// DisableForeignKeysAndTriggers indicates an expected call of DisableForeignKeysAndTriggers.
func (mr *MockStoreMockRecorder) DisableForeignKeysAndTriggers(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisableForeignKeysAndTriggers", reflect.TypeOf((*MockStore)(nil).DisableForeignKeysAndTriggers), arg0)
}
// EnqueueNotificationMessage mocks base method.
func (m *MockStore) EnqueueNotificationMessage(arg0 context.Context, arg1 database.EnqueueNotificationMessageParams) error {
m.ctrl.T.Helper()

View File

@ -87,6 +87,18 @@ func NewDBWithSQLDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub
return db, ps, sqlDB
}
var DefaultTimezone = "Canada/Newfoundland"
// NowInDefaultTimezone returns the current time rounded to the nearest microsecond in the default timezone
// used by postgres in tests. Useful for object equality checks.
func NowInDefaultTimezone() time.Time {
loc, err := time.LoadLocation(DefaultTimezone)
if err != nil {
panic(err)
}
return time.Now().In(loc).Round(time.Microsecond)
}
func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
t.Helper()
@ -115,7 +127,7 @@ func NewDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub) {
// - It has a non-UTC offset
// - It has a fractional hour UTC offset
// - It includes a daylight savings time component
o.fixedTimezone = "Canada/Newfoundland"
o.fixedTimezone = DefaultTimezone
}
dbName := dbNameFromConnectionURL(t, connectionURL)
setDBTimezone(t, connectionURL, dbName, o.fixedTimezone)
@ -318,3 +330,15 @@ func normalizeDump(schema []byte) []byte {
return schema
}
// Deprecated: disable foreign keys was created to aid in migrating off
// of the test-only in-memory database. Do not use this in new code.
func DisableForeignKeysAndTriggers(t *testing.T, db database.Store) {
err := db.DisableForeignKeysAndTriggers(context.Background())
if t != nil {
require.NoError(t, err)
}
if err != nil {
panic(err)
}
}

View File

@ -106,6 +106,10 @@ type sqlcQuerier interface {
DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error)
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
// Disable foreign keys and triggers for all tables.
// Deprecated: disable foreign keys was created to aid in migrating off
// of the test-only in-memory database. Do not use this in new code.
DisableForeignKeysAndTriggers(ctx context.Context) error
EnqueueNotificationMessage(ctx context.Context, arg EnqueueNotificationMessageParams) error
FavoriteWorkspace(ctx context.Context, id uuid.UUID) error
// This is used to build up the notification_message's JSON payload.

View File

@ -9796,6 +9796,33 @@ func (q *sqlQuerier) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg
return i, err
}
const disableForeignKeysAndTriggers = `-- name: DisableForeignKeysAndTriggers :exec
DO $$
DECLARE
table_record record;
BEGIN
FOR table_record IN
SELECT table_schema, table_name
FROM information_schema.tables
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_type = 'BASE TABLE'
LOOP
EXECUTE format('ALTER TABLE %I.%I DISABLE TRIGGER ALL',
table_record.table_schema,
table_record.table_name);
END LOOP;
END;
$$
`
// Disable foreign keys and triggers for all tables.
// Deprecated: disable foreign keys was created to aid in migrating off
// of the test-only in-memory database. Do not use this in new code.
func (q *sqlQuerier) DisableForeignKeysAndTriggers(ctx context.Context) error {
_, err := q.db.ExecContext(ctx, disableForeignKeysAndTriggers)
return err
}
const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one
SELECT
user_links.user_id, user_links.login_type, user_links.linked_id, user_links.oauth_access_token, user_links.oauth_refresh_token, user_links.oauth_expiry, user_links.oauth_access_token_key_id, user_links.oauth_refresh_token_key_id, user_links.claims

View File

@ -0,0 +1,20 @@
-- name: DisableForeignKeysAndTriggers :exec
-- Disable foreign keys and triggers for all tables.
-- Deprecated: disable foreign keys was created to aid in migrating off
-- of the test-only in-memory database. Do not use this in new code.
DO $$
DECLARE
table_record record;
BEGIN
FOR table_record IN
SELECT table_schema, table_name
FROM information_schema.tables
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
AND table_type = 'BASE TABLE'
LOOP
EXECUTE format('ALTER TABLE %I.%I DISABLE TRIGGER ALL',
table_record.table_schema,
table_record.table_name);
END LOOP;
END;
$$;