feat: Add initial AuthzQuerier implementation (#5919)

feat: Add initial AuthzQuerier implementation
- Adds package database/dbauthz that adds a database.Store implementation where each method goes through AuthZ checks
- Implements all database.Store methods on AuthzQuerier
- Updates and fixes unit tests where required
- Updates coderd initialization to use AuthzQuerier if codersdk.ExperimentAuthzQuerier is enabled
This commit is contained in:
Steven Masley
2023-02-14 08:27:06 -06:00
committed by GitHub
parent ebdfdc749d
commit 6fb8aff6d0
59 changed files with 5013 additions and 136 deletions

View File

@ -0,0 +1,387 @@
package dbauthz
import (
"context"
"database/sql"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/open-policy-agent/opa/topdown"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/rbac"
)
var _ database.Store = (*querier)(nil)
var (
// NoActorError wraps ErrNoRows for the api to return a 404. This is the correct
// response when the user is not authorized.
NoActorError = xerrors.Errorf("no authorization actor in context: %w", sql.ErrNoRows)
)
// NotAuthorizedError is a sentinel error that unwraps to sql.ErrNoRows.
// This allows the internal error to be read by the caller if needed. Otherwise
// it will be handled as a 404.
type NotAuthorizedError struct {
Err error
}
func (e NotAuthorizedError) Error() string {
return fmt.Sprintf("unauthorized: %s", e.Err.Error())
}
// Unwrap will always unwrap to a sql.ErrNoRows so the API returns a 404.
// So 'errors.Is(err, sql.ErrNoRows)' will always be true.
func (NotAuthorizedError) Unwrap() error {
return sql.ErrNoRows
}
func logNotAuthorizedError(ctx context.Context, logger slog.Logger, err error) error {
// Only log the errors if it is an UnauthorizedError error.
internalError := new(rbac.UnauthorizedError)
if err != nil && xerrors.As(err, &internalError) {
e := new(topdown.Error)
if xerrors.As(err, &e) || e.Code == topdown.CancelErr {
// For some reason rego changes a cancelled context to a topdown.CancelErr. We
// expect to check for cancelled context errors if the user cancels the request,
// so we should change the error to a context.Canceled error.
//
// NotAuthorizedError is == to sql.ErrNoRows, which is not correct
// if it's actually a cancelled context.
internalError.SetInternal(context.Canceled)
return internalError
}
logger.Debug(ctx, "unauthorized",
slog.F("internal", internalError.Internal()),
slog.F("input", internalError.Input()),
slog.Error(err),
)
}
return NotAuthorizedError{
Err: err,
}
}
// querier is a wrapper around the database store that performs authorization
// checks before returning data. All querier methods expect an authorization
// subject present in the context. If no subject is present, most methods will
// fail.
//
// Use WithAuthorizeContext to set the authorization subject in the context for
// the common user case.
type querier struct {
db database.Store
auth rbac.Authorizer
log slog.Logger
}
func New(db database.Store, authorizer rbac.Authorizer, logger slog.Logger) database.Store {
// If the underlying db store is already a querier, return it.
// Do not double wrap.
if _, ok := db.(*querier); ok {
return db
}
return &querier{
db: db,
auth: authorizer,
log: logger,
}
}
// authorizeContext is a helper function to authorize an action on an object.
func (q *querier) authorizeContext(ctx context.Context, action rbac.Action, object rbac.Objecter) error {
act, ok := ActorFromContext(ctx)
if !ok {
return NoActorError
}
err := q.auth.Authorize(ctx, act, action, object.RBACObject())
if err != nil {
return logNotAuthorizedError(ctx, q.log, err)
}
return nil
}
type authContextKey struct{}
// ActorFromContext returns the authorization subject from the context.
// All authentication flows should set the authorization subject in the context.
// If no actor is present, the function returns false.
func ActorFromContext(ctx context.Context) (rbac.Subject, bool) {
a, ok := ctx.Value(authContextKey{}).(rbac.Subject)
return a, ok
}
// AsSystem returns a context with a system actor. This is used for internal
// system operations that are not tied to any particular actor.
// When you use this function, be sure to add a //nolint comment
// explaining why it is necessary.
//
// We trust you have received the usual lecture from the local System
// Administrator. It usually boils down to these three things:
// #1) Respect the privacy of others.
// #2) Think before you type.
// #3) With great power comes great responsibility.
func AsSystem(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Name: "system",
DisplayName: "System",
Site: []rbac.Permission{
{
ResourceType: rbac.ResourceWildcard.Type,
Action: rbac.WildcardSymbol,
},
},
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
},
)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
// As returns a context with the given actor stored in the context.
// This is used for cases where the actor touching the database is not the
// actor stored in the context.
// When you use this function, be sure to add a //nolint comment
// explaining why it is necessary.
func As(ctx context.Context, actor rbac.Subject) context.Context {
if actor.Equal(AsRemoveActor) {
// AsRemoveActor is a special case that is used to indicate that the actor
// should be removed from the context.
return context.WithValue(ctx, authContextKey{}, nil)
}
return context.WithValue(ctx, authContextKey{}, actor)
}
//
// Generic functions used to implement the database.Store methods.
//
// insert runs an rbac.ActionCreate on the rbac object argument before
// running the insertFunc. The insertFunc is expected to return the object that
// was inserted.
func insert[
ObjectType any,
ArgumentType any,
Insert func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
object rbac.Objecter,
insertFunc Insert,
) Insert {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Authorize the action
err = authorizer.Authorize(ctx, act, rbac.ActionCreate, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
// Insert the database object
return insertFunc(ctx, arg)
}
}
func deleteQ[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Delete func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
deleteFunc Delete,
) Delete {
return fetchAndExec(logger, authorizer,
rbac.ActionDelete, fetchFunc, deleteFunc)
}
func updateWithReturn[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
UpdateQuery func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
updateQuery UpdateQuery,
) UpdateQuery {
return fetchAndQuery(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateQuery)
}
func update[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Exec func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
fetchFunc Fetch,
updateExec Exec,
) Exec {
return fetchAndExec(logger, authorizer, rbac.ActionUpdate, fetchFunc, updateExec)
}
// fetch is a generic function that wraps a database
// query function (returns an object and an error) with authorization. The
// returned function has the same arguments as the database function.
//
// The database query function will **ALWAYS** hit the database, even if the
// user cannot read the resource. This is because the resource details are
// required to run a proper authorization check.
func fetch[
ArgumentType any,
ObjectType rbac.Objecter,
DatabaseFunc func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
f DatabaseFunc,
) DatabaseFunc {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
object, err := f(ctx, arg)
if err != nil {
return empty, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
err = authorizer.Authorize(ctx, act, rbac.ActionRead, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
return object, nil
}
}
// fetchAndExec uses fetchAndQuery but only returns the error. The naming comes
// from SQL 'exec' functions which only return an error.
// See fetchAndQuery for more information.
func fetchAndExec[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Exec func(ctx context.Context, arg ArgumentType) error,
](
logger slog.Logger,
authorizer rbac.Authorizer,
action rbac.Action,
fetchFunc Fetch,
execFunc Exec,
) Exec {
f := fetchAndQuery(logger, authorizer, action, fetchFunc, func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
return empty, execFunc(ctx, arg)
})
return func(ctx context.Context, arg ArgumentType) error {
_, err := f(ctx, arg)
return err
}
}
// fetchAndQuery is a generic function that wraps a database fetch and query.
// A query has potential side effects in the database (update, delete, etc).
// The fetch is used to know which rbac object the action should be asserted on
// **before** the query runs. The returns from the fetch are only used to
// assert rbac. The final return of this function comes from the Query function.
func fetchAndQuery[
ObjectType rbac.Objecter,
ArgumentType any,
Fetch func(ctx context.Context, arg ArgumentType) (ObjectType, error),
Query func(ctx context.Context, arg ArgumentType) (ObjectType, error),
](
logger slog.Logger,
authorizer rbac.Authorizer,
action rbac.Action,
fetchFunc Fetch,
queryFunc Query,
) Query {
return func(ctx context.Context, arg ArgumentType) (empty ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
object, err := fetchFunc(ctx, arg)
if err != nil {
return empty, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
err = authorizer.Authorize(ctx, act, action, object.RBACObject())
if err != nil {
return empty, logNotAuthorizedError(ctx, logger, err)
}
return queryFunc(ctx, arg)
}
}
// fetchWithPostFilter is like fetch, but works with lists of objects.
// SQL filters are much more optimal.
func fetchWithPostFilter[
ArgumentType any,
ObjectType rbac.Objecter,
DatabaseFunc func(ctx context.Context, arg ArgumentType) ([]ObjectType, error),
](
authorizer rbac.Authorizer,
f DatabaseFunc,
) DatabaseFunc {
return func(ctx context.Context, arg ArgumentType) (empty []ObjectType, err error) {
// Fetch the rbac subject
act, ok := ActorFromContext(ctx)
if !ok {
return empty, NoActorError
}
// Fetch the database object
objects, err := f(ctx, arg)
if err != nil {
return nil, xerrors.Errorf("fetch object: %w", err)
}
// Authorize the action
return rbac.Filter(ctx, authorizer, act, rbac.ActionRead, objects)
}
}
// prepareSQLFilter is a helper function that prepares a SQL filter using the
// given authorization context.
func prepareSQLFilter(ctx context.Context, authorizer rbac.Authorizer, action rbac.Action, resourceType string) (rbac.PreparedAuthorized, error) {
act, ok := ActorFromContext(ctx)
if !ok {
return nil, NoActorError
}
return authorizer.Prepare(ctx, act, action, resourceType)
}

View File

@ -0,0 +1,151 @@
package dbauthz_test
import (
"context"
"reflect"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/rbac"
)
func TestAsNoActor(t *testing.T) {
t.Parallel()
t.Run("AsRemoveActor", func(t *testing.T) {
t.Parallel()
_, ok := dbauthz.ActorFromContext(context.Background())
require.False(t, ok, "no actor should be present")
})
t.Run("AsActor", func(t *testing.T) {
t.Parallel()
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
_, ok := dbauthz.ActorFromContext(ctx)
require.True(t, ok, "actor present")
})
t.Run("DeleteActor", func(t *testing.T) {
t.Parallel()
// First set an actor
ctx := dbauthz.As(context.Background(), coderdtest.RandomRBACSubject())
_, ok := dbauthz.ActorFromContext(ctx)
require.True(t, ok, "actor present")
// Delete the actor
ctx = dbauthz.As(ctx, dbauthz.AsRemoveActor)
_, ok = dbauthz.ActorFromContext(ctx)
require.False(t, ok, "actor should be deleted")
})
}
func TestPing(t *testing.T) {
t.Parallel()
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{}, slog.Make())
_, err := q.Ping(context.Background())
require.NoError(t, err, "must not error")
}
// TestInTX is not perfect, just checks that it properly checks auth.
func TestInTX(t *testing.T) {
t.Parallel()
db := dbfake.New()
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")},
}, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
w := dbgen.Workspace(t, db, database.Workspace{})
ctx := dbauthz.As(context.Background(), actor)
err := q.InTx(func(tx database.Store) error {
// The inner tx should use the parent's authz
_, err := tx.GetWorkspaceByID(ctx, w.ID)
return err
}, nil)
require.Error(t, err, "must error")
require.ErrorAs(t, err, &dbauthz.NotAuthorizedError{}, "must be an authorized error")
}
// TestNew should not double wrap a querier.
func TestNew(t *testing.T) {
t.Parallel()
var (
db = dbfake.New()
exp = dbgen.Workspace(t, db, database.Workspace{})
rec = &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
}
subj = rbac.Subject{}
ctx = dbauthz.As(context.Background(), rbac.Subject{})
)
// Double wrap should not cause an actual double wrap. So only 1 rbac call
// should be made.
az := dbauthz.New(db, rec, slog.Make())
az = dbauthz.New(az, rec, slog.Make())
w, err := az.GetWorkspaceByID(ctx, exp.ID)
require.NoError(t, err, "must not error")
require.Equal(t, exp, w, "must be equal")
rec.AssertActor(t, subj, rec.Pair(rbac.ActionRead, exp))
require.NoError(t, rec.AllAsserted(), "should only be 1 rbac call")
}
// TestDBAuthzRecursive is a simple test to search for infinite recursion
// bugs. It isn't perfect, and only catches a subset of the possible bugs
// as only the first db call will be made. But it is better than nothing.
func TestDBAuthzRecursive(t *testing.T) {
t.Parallel()
q := dbauthz.New(dbfake.New(), &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
}, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
for i := 0; i < reflect.TypeOf(q).NumMethod(); i++ {
var ins []reflect.Value
ctx := dbauthz.As(context.Background(), actor)
ins = append(ins, reflect.ValueOf(ctx))
method := reflect.TypeOf(q).Method(i)
for i := 2; i < method.Type.NumIn(); i++ {
ins = append(ins, reflect.New(method.Type.In(i)).Elem())
}
if method.Name == "InTx" || method.Name == "Ping" {
continue
}
// Log the name of the last method, so if there is a panic, it is
// easy to know which method failed.
// t.Log(method.Name)
// Call the function. Any infinite recursion will stack overflow.
reflect.ValueOf(q).Method(i).Call(ins)
}
}
func must[T any](value T, err error) T {
if err != nil {
panic(err)
}
return value
}

View File

@ -0,0 +1,17 @@
// Package dbauthz provides an authorization layer on top of the database. This
// package exposes an interface that is currently a 1:1 mapping with
// database.Store.
//
// The same cultural rules apply to this package as they do to database.Store.
// Meaning that each method implemented should keep the number of database
// queries as close to 1 as possible. Each method should do 1 thing, with no
// unexpected side effects (eg: updating multiple tables in a single method).
//
// Do not implement business logic in this package. Only authorization related
// logic should be implemented here. In most cases, this should only be a call to
// the rbac authorizer.
//
// When a new database method is added to database.Store, it should be added to
// this package as well. The unit test "Accounting" will ensure all methods are
// tested. See other unit tests for examples on how to write these.
package dbauthz

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,377 @@
package dbauthz_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"testing"
"golang.org/x/xerrors"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"cdr.dev/slog"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
"github.com/coder/coder/coderd/util/slice"
)
var (
skipMethods = map[string]string{
"InTx": "Not relevant",
"Ping": "Not relevant",
}
)
// TestMethodTestSuite runs MethodTestSuite.
// In order for 'go test' to run this suite, we need to create
// a normal test function and pass our suite to suite.Run
// nolint: paralleltest
func TestMethodTestSuite(t *testing.T) {
suite.Run(t, new(MethodTestSuite))
}
// MethodTestSuite runs all methods tests for querier. We use
// a test suite so we can account for all functions tested on the querier.
// We can then assert all methods were tested and asserted for proper RBAC
// checks. This forces RBAC checks to be written for all methods.
// Additionally, the way unit tests are written allows for easily executing
// a single test for debugging.
type MethodTestSuite struct {
suite.Suite
// methodAccounting counts all methods called by a 'RunMethodTest'
methodAccounting map[string]int
}
// SetupSuite sets up the suite by creating a map of all methods on querier
// and setting their count to 0.
func (s *MethodTestSuite) SetupSuite() {
az := dbauthz.New(nil, nil, slog.Make())
// Take the underlying type of the interface.
azt := reflect.TypeOf(az).Elem()
s.methodAccounting = make(map[string]int)
for i := 0; i < azt.NumMethod(); i++ {
method := azt.Method(i)
if _, ok := skipMethods[method.Name]; ok {
// We can't use s.T().Skip as this will skip the entire suite.
s.T().Logf("Skipping method %q: %s", method.Name, skipMethods[method.Name])
continue
}
s.methodAccounting[method.Name] = 0
}
}
// TearDownSuite asserts that all methods were called at least once.
func (s *MethodTestSuite) TearDownSuite() {
s.Run("Accounting", func() {
t := s.T()
notCalled := []string{}
for m, c := range s.methodAccounting {
if c <= 0 {
notCalled = append(notCalled, m)
}
}
sort.Strings(notCalled)
for _, m := range notCalled {
t.Errorf("Method never called: %q", m)
}
})
}
// Subtest is a helper function that returns a function that can be passed to
// s.Run(). This function will run the test case for the method that is being
// tested. The check parameter is used to assert the results of the method.
// If the caller does not use the `check` parameter, the test will fail.
func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expects)) func() {
return func() {
t := s.T()
testName := s.T().Name()
names := strings.Split(testName, "/")
methodName := names[len(names)-1]
s.methodAccounting[methodName]++
db := dbfake.New()
fakeAuthorizer := &coderdtest.FakeAuthorizer{
AlwaysReturn: nil,
}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
}
az := dbauthz.New(db, rec, slog.Make())
actor := rbac.Subject{
ID: uuid.NewString(),
Roles: rbac.RoleNames{rbac.RoleOwner()},
Groups: []string{},
Scope: rbac.ScopeAll,
}
ctx := dbauthz.As(context.Background(), actor)
var testCase expects
testCaseF(db, &testCase)
// Check the developer added assertions. If there are no assertions,
// an empty list should be passed.
s.Require().False(testCase.assertions == nil, "rbac assertions not set, use the 'check' parameter")
// Find the method with the name of the test.
var callMethod func(ctx context.Context) ([]reflect.Value, error)
azt := reflect.TypeOf(az)
MethodLoop:
for i := 0; i < azt.NumMethod(); i++ {
method := azt.Method(i)
if method.Name == methodName {
methodF := reflect.ValueOf(az).Method(i)
callMethod = func(ctx context.Context) ([]reflect.Value, error) {
resp := methodF.Call(append([]reflect.Value{reflect.ValueOf(ctx)}, testCase.inputs...))
return splitResp(t, resp)
}
break MethodLoop
}
}
require.NotNil(t, callMethod, "method %q does not exist", methodName)
if len(testCase.assertions) > 0 {
// Only run these tests if we know the underlying call makes
// rbac assertions.
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
}
if len(testCase.assertions) > 0 ||
slice.Contains([]string{
"GetAuthorizedWorkspaces",
"GetAuthorizedTemplates",
}, methodName) {
// Some methods do not make RBAC assertions because they use
// SQL. We still want to test that they return an error if the
// actor is not set.
s.NoActorErrorTest(callMethod)
}
// Always run
s.Run("Success", func() {
rec.Reset()
fakeAuthorizer.AlwaysReturn = nil
outputs, err := callMethod(ctx)
s.NoError(err, "method %q returned an error", methodName)
// Some tests may not care about the outputs, so we only assert if
// they are provided.
if testCase.outputs != nil {
// Assert the required outputs
s.Equal(len(testCase.outputs), len(outputs), "method %q returned unexpected number of outputs", methodName)
for i := range outputs {
a, b := testCase.outputs[i].Interface(), outputs[i].Interface()
if reflect.TypeOf(a).Kind() == reflect.Slice || reflect.TypeOf(a).Kind() == reflect.Array {
// Order does not matter
s.ElementsMatch(a, b, "method %q returned unexpected output %d", methodName, i)
} else {
s.Equal(a, b, "method %q returned unexpected output %d", methodName, i)
}
}
}
var pairs []coderdtest.ActionObjectPair
for _, assrt := range testCase.assertions {
for _, action := range assrt.Actions {
pairs = append(pairs, coderdtest.ActionObjectPair{
Action: action,
Object: assrt.Object,
})
}
}
rec.AssertActor(s.T(), actor, pairs...)
s.NoError(rec.AllAsserted(), "all rbac calls must be asserted")
})
}
}
func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("AsRemoveActor", func() {
// Call without any actor
_, err := callMethod(context.Background())
s.ErrorIs(err, dbauthz.NoActorError, "method should return NoActorError error when no actor is provided")
})
}
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
// Asserts that the error returned is a NotAuthorizedError.
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
// If we have assertions, that means the method should FAIL
// if RBAC will disallow the request. The returned error should
// be expected to be a NotAuthorizedError.
resp, err := callMethod(ctx)
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorIsf(err, sql.ErrNoRows, "error should match sql.ErrNoRows")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
}
})
}
func hasEmptySliceResponse(values []reflect.Value) bool {
for _, r := range values {
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
if r.Len() == 0 {
return true
}
}
}
return false
}
func splitResp(t *testing.T, values []reflect.Value) ([]reflect.Value, error) {
outputs := []reflect.Value{}
for _, r := range values {
if r.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
if r.IsNil() {
// Error is found, but it's nil!
return outputs, nil
}
err, ok := r.Interface().(error)
if !ok {
t.Fatal("error is not an error?!")
}
return outputs, err
}
outputs = append(outputs, r)
} //nolint: unreachable
t.Fatal("no expected error value found in responses (error can be nil)")
return nil, nil // unreachable, required to compile
}
// expects is used to build a test case for a method.
// It includes the expected inputs, rbac assertions, and expected outputs.
type expects struct {
inputs []reflect.Value
assertions []AssertRBAC
// outputs is optional. Can assert non-error return values.
outputs []reflect.Value
}
// Asserts is required. Asserts the RBAC authorize calls that should be made.
// If no RBAC calls are expected, pass an empty list: 'm.Asserts()'
func (m *expects) Asserts(pairs ...any) *expects {
m.assertions = asserts(pairs...)
return m
}
// Args is required. The arguments to be provided to the method.
// If there are no arguments, pass an empty list: 'm.Args()'
// The first context argument should not be included, as the test suite
// will provide it.
func (m *expects) Args(args ...any) *expects {
m.inputs = values(args...)
return m
}
// Returns is optional. If it is never called, it will not be asserted.
func (m *expects) Returns(rets ...any) *expects {
m.outputs = values(rets...)
return m
}
// AssertRBAC contains the object and actions to be asserted.
type AssertRBAC struct {
Object rbac.Object
Actions []rbac.Action
}
// values is a convenience method for creating []reflect.Value.
//
// values(workspace, template, ...)
//
// is equivalent to
//
// []reflect.Value{
// reflect.ValueOf(workspace),
// reflect.ValueOf(template),
// ...
// }
func values(ins ...any) []reflect.Value {
out := make([]reflect.Value, 0)
for _, input := range ins {
input := input
out = append(out, reflect.ValueOf(input))
}
return out
}
// asserts is a convenience method for creating AssertRBACs.
//
// The number of inputs must be an even number.
// asserts() will panic if this is not the case.
//
// Even-numbered inputs are the objects, and odd-numbered inputs are the actions.
// Objects must implement rbac.Objecter.
// Inputs can be a single rbac.Action, or a slice of rbac.Action.
//
// asserts(workspace, rbac.ActionRead, template, slice(rbac.ActionRead, rbac.ActionWrite), ...)
//
// is equivalent to
//
// []AssertRBAC{
// {Object: workspace, Actions: []rbac.Action{rbac.ActionRead}},
// {Object: template, Actions: []rbac.Action{rbac.ActionRead, rbac.ActionWrite)}},
// ...
// }
func asserts(inputs ...any) []AssertRBAC {
if len(inputs)%2 != 0 {
panic(fmt.Sprintf("Must be an even length number of args, found %d", len(inputs)))
}
out := make([]AssertRBAC, 0)
for i := 0; i < len(inputs); i += 2 {
obj, ok := inputs[i].(rbac.Objecter)
if !ok {
panic(fmt.Sprintf("object type '%T' does not implement rbac.Objecter", inputs[i]))
}
rbacObj := obj.RBACObject()
var actions []rbac.Action
actions, ok = inputs[i+1].([]rbac.Action)
if !ok {
action, ok := inputs[i+1].(rbac.Action)
if !ok {
// Could be the string type.
actionAsString, ok := inputs[i+1].(string)
if !ok {
panic(fmt.Sprintf("action '%q' not a supported action", actionAsString))
}
action = rbac.Action(actionAsString)
}
actions = []rbac.Action{action}
}
out = append(out, AssertRBAC{
Object: rbacObj,
Actions: actions,
})
}
return out
}
type emptyPreparedAuthorized struct{}
func (emptyPreparedAuthorized) Authorize(_ context.Context, _ rbac.Object) error { return nil }
func (emptyPreparedAuthorized) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
return "", nil
}

View File

@ -0,0 +1,194 @@
package dbauthz
import (
"context"
"time"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
)
// TODO: All these system functions should have rbac objects created to allow
// only system roles to call them. No user roles should ever have the permission
// to these objects. Might need a negative permission on the `Owner` role to
// prevent owners.
func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
return q.db.UpdateUserLinkedID(ctx, arg)
}
func (q *querier) GetUserLinkByLinkedID(ctx context.Context, linkedID string) (database.UserLink, error) {
return q.db.GetUserLinkByLinkedID(ctx, linkedID)
}
func (q *querier) GetUserLinkByUserIDLoginType(ctx context.Context, arg database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
return q.db.GetUserLinkByUserIDLoginType(ctx, arg)
}
func (q *querier) GetLatestWorkspaceBuilds(ctx context.Context) ([]database.WorkspaceBuild, error) {
// This function is a system function until we implement a join for workspace builds.
// This is because we need to query for all related workspaces to the returned builds.
// This is a very inefficient method of fetching the latest workspace builds.
// We should just join the rbac properties.
return q.db.GetLatestWorkspaceBuilds(ctx)
}
// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent.
// This should only be used by a system user in that middleware.
func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) {
return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken)
}
func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) {
return q.db.GetActiveUserCount(ctx)
}
func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, error) {
return q.db.GetUnexpiredLicenses(ctx)
}
func (q *querier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (database.GetAuthorizationUserRolesRow, error) {
return q.db.GetAuthorizationUserRoles(ctx, userID)
}
func (q *querier) GetDERPMeshKey(ctx context.Context) (string, error) {
// TODO Implement authz check for system user.
return q.db.GetDERPMeshKey(ctx)
}
func (q *querier) InsertDERPMeshKey(ctx context.Context, value string) error {
// TODO Implement authz check for system user.
return q.db.InsertDERPMeshKey(ctx, value)
}
func (q *querier) InsertDeploymentID(ctx context.Context, value string) error {
// TODO Implement authz check for system user.
return q.db.InsertDeploymentID(ctx, value)
}
func (q *querier) InsertReplica(ctx context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.InsertReplica(ctx, arg)
}
func (q *querier) UpdateReplica(ctx context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.UpdateReplica(ctx, arg)
}
func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
// TODO Implement authz check for system user.
return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt)
}
func (q *querier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]database.Replica, error) {
// TODO Implement authz check for system user.
return q.db.GetReplicasUpdatedAfter(ctx, updatedAt)
}
func (q *querier) GetUserCount(ctx context.Context) (int64, error) {
return q.db.GetUserCount(ctx)
}
func (q *querier) GetTemplates(ctx context.Context) ([]database.Template, error) {
// TODO Implement authz check for system user.
return q.db.GetTemplates(ctx)
}
// UpdateWorkspaceBuildCostByID is used by the provisioning system to update the cost of a workspace build.
func (q *querier) UpdateWorkspaceBuildCostByID(ctx context.Context, arg database.UpdateWorkspaceBuildCostByIDParams) (database.WorkspaceBuild, error) {
return q.db.UpdateWorkspaceBuildCostByID(ctx, arg)
}
func (q *querier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error {
return q.db.InsertOrUpdateLastUpdateCheck(ctx, value)
}
func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) {
return q.db.GetLastUpdateCheck(ctx)
}
// Telemetry related functions. These functions are system functions for returning
// telemetry data. Never called by a user.
func (q *querier) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceBuild, error) {
return q.db.GetWorkspaceBuildsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceAgent, error) {
return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceAppsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceApp, error) {
return q.db.GetWorkspaceAppsCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceResourcesCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResource, error) {
return q.db.GetWorkspaceResourcesCreatedAfter(ctx, createdAt)
}
func (q *querier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.WorkspaceResourceMetadatum, error) {
return q.db.GetWorkspaceResourceMetadataCreatedAfter(ctx, createdAt)
}
func (q *querier) DeleteOldAgentStats(ctx context.Context) error {
return q.db.DeleteOldAgentStats(ctx)
}
func (q *querier) GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ParameterSchema, error) {
return q.db.GetParameterSchemasCreatedAfter(ctx, createdAt)
}
func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]database.ProvisionerJob, error) {
return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt)
}
// Provisionerd server functions
func (q *querier) InsertWorkspaceAgent(ctx context.Context, arg database.InsertWorkspaceAgentParams) (database.WorkspaceAgent, error) {
return q.db.InsertWorkspaceAgent(ctx, arg)
}
func (q *querier) InsertWorkspaceApp(ctx context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) {
return q.db.InsertWorkspaceApp(ctx, arg)
}
func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg database.InsertWorkspaceResourceMetadataParams) ([]database.WorkspaceResourceMetadatum, error) {
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
}
func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.AcquireProvisionerJobParams) (database.ProvisionerJob, error) {
return q.db.AcquireProvisionerJob(ctx, arg)
}
func (q *querier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg database.UpdateProvisionerJobWithCompleteByIDParams) error {
return q.db.UpdateProvisionerJobWithCompleteByID(ctx, arg)
}
func (q *querier) UpdateProvisionerJobByID(ctx context.Context, arg database.UpdateProvisionerJobByIDParams) error {
return q.db.UpdateProvisionerJobByID(ctx, arg)
}
func (q *querier) InsertProvisionerJob(ctx context.Context, arg database.InsertProvisionerJobParams) (database.ProvisionerJob, error) {
return q.db.InsertProvisionerJob(ctx, arg)
}
func (q *querier) InsertProvisionerJobLogs(ctx context.Context, arg database.InsertProvisionerJobLogsParams) ([]database.ProvisionerJobLog, error) {
return q.db.InsertProvisionerJobLogs(ctx, arg)
}
func (q *querier) InsertProvisionerDaemon(ctx context.Context, arg database.InsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) {
return q.db.InsertProvisionerDaemon(ctx, arg)
}
func (q *querier) InsertTemplateVersionParameter(ctx context.Context, arg database.InsertTemplateVersionParameterParams) (database.TemplateVersionParameter, error) {
return q.db.InsertTemplateVersionParameter(ctx, arg)
}
func (q *querier) InsertWorkspaceResource(ctx context.Context, arg database.InsertWorkspaceResourceParams) (database.WorkspaceResource, error) {
return q.db.InsertWorkspaceResource(ctx, arg)
}
func (q *querier) InsertParameterSchema(ctx context.Context, arg database.InsertParameterSchemaParams) (database.ParameterSchema, error) {
return q.db.InsertParameterSchema(ctx, arg)
}

View File

@ -0,0 +1,219 @@
package dbauthz_test
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbgen"
)
func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("UpdateUserLinkedID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
l := dbgen.UserLink(s.T(), db, database.UserLink{UserID: u.ID})
check.Args(database.UpdateUserLinkedIDParams{
UserID: u.ID,
LinkedID: l.LinkedID,
LoginType: database.LoginTypeGithub,
}).Asserts().Returns(l)
}))
s.Run("GetUserLinkByLinkedID", s.Subtest(func(db database.Store, check *expects) {
l := dbgen.UserLink(s.T(), db, database.UserLink{})
check.Args(l.LinkedID).Asserts().Returns(l)
}))
s.Run("GetUserLinkByUserIDLoginType", s.Subtest(func(db database.Store, check *expects) {
l := dbgen.UserLink(s.T(), db, database.UserLink{})
check.Args(database.GetUserLinkByUserIDLoginTypeParams{
UserID: l.UserID,
LoginType: l.LoginType,
}).Asserts().Returns(l)
}))
s.Run("GetLatestWorkspaceBuilds", s.Subtest(func(db database.Store, check *expects) {
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
check.Args().Asserts()
}))
s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) {
agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{})
check.Args(agt.AuthToken).Asserts().Returns(agt)
}))
s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts().Returns(int64(0))
}))
s.Run("GetUnexpiredLicenses", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("GetAuthorizationUserRoles", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
check.Args(u.ID).Asserts()
}))
s.Run("GetDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("InsertDERPMeshKey", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts().Returns()
}))
s.Run("InsertDeploymentID", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts().Returns()
}))
s.Run("InsertReplica", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertReplicaParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("UpdateReplica", s.Subtest(func(db database.Store, check *expects) {
replica, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New()})
require.NoError(s.T(), err)
check.Args(database.UpdateReplicaParams{
ID: replica.ID,
DatabaseLatency: 100,
}).Asserts()
}))
s.Run("DeleteReplicasUpdatedBefore", s.Subtest(func(db database.Store, check *expects) {
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
require.NoError(s.T(), err)
check.Args(time.Now().Add(time.Hour)).Asserts()
}))
s.Run("GetReplicasUpdatedAfter", s.Subtest(func(db database.Store, check *expects) {
_, err := db.InsertReplica(context.Background(), database.InsertReplicaParams{ID: uuid.New(), UpdatedAt: time.Now()})
require.NoError(s.T(), err)
check.Args(time.Now().Add(time.Hour * -1)).Asserts()
}))
s.Run("GetUserCount", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts().Returns(int64(0))
}))
s.Run("GetTemplates", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.Template(s.T(), db, database.Template{})
check.Args().Asserts()
}))
s.Run("UpdateWorkspaceBuildCostByID", s.Subtest(func(db database.Store, check *expects) {
b := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{})
o := b
o.DailyCost = 10
check.Args(database.UpdateWorkspaceBuildCostByIDParams{
ID: b.ID,
DailyCost: 10,
}).Asserts().Returns(o)
}))
s.Run("InsertOrUpdateLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
check.Args("value").Asserts()
}))
s.Run("GetLastUpdateCheck", s.Subtest(func(db database.Store, check *expects) {
err := db.InsertOrUpdateLastUpdateCheck(context.Background(), "value")
require.NoError(s.T(), err)
check.Args().Asserts()
}))
s.Run("GetWorkspaceBuildsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceAgentsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceAppsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceApp(s.T(), db, database.WorkspaceApp{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceResourcesCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetWorkspaceResourceMetadataCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.WorkspaceResourceMetadatums(s.T(), db, database.WorkspaceResourceMetadatum{})
check.Args(time.Now()).Asserts()
}))
s.Run("DeleteOldAgentStats", s.Subtest(func(db database.Store, check *expects) {
check.Args().Asserts()
}))
s.Run("GetParameterSchemasCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.ParameterSchema(s.T(), db, database.ParameterSchema{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts()
}))
s.Run("InsertWorkspaceAgent", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceAgentParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("InsertWorkspaceApp", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceAppParams{
ID: uuid.New(),
Health: database.WorkspaceAppHealthDisabled,
SharingLevel: database.AppSharingLevelOwner,
}).Asserts()
}))
s.Run("InsertWorkspaceResourceMetadata", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertWorkspaceResourceMetadataParams{
WorkspaceResourceID: uuid.New(),
}).Asserts()
}))
s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}}).
Asserts()
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{
ID: j.ID,
}).Asserts()
}))
s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobByIDParams{
ID: j.ID,
UpdatedAt: time.Now(),
}).Asserts()
}))
s.Run("InsertProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertProvisionerJobParams{
ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
}).Asserts()
}))
s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
check.Args(database.InsertProvisionerJobLogsParams{
JobID: j.ID,
}).Asserts()
}))
s.Run("InsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertProvisionerDaemonParams{
ID: uuid.New(),
}).Asserts()
}))
s.Run("InsertTemplateVersionParameter", s.Subtest(func(db database.Store, check *expects) {
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{})
check.Args(database.InsertTemplateVersionParameterParams{
TemplateVersionID: v.ID,
}).Asserts()
}))
s.Run("InsertWorkspaceResource", s.Subtest(func(db database.Store, check *expects) {
r := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{})
check.Args(database.InsertWorkspaceResourceParams{
ID: r.ID,
Transition: database.WorkspaceTransitionStart,
}).Asserts()
}))
s.Run("InsertParameterSchema", s.Subtest(func(db database.Store, check *expects) {
check.Args(database.InsertParameterSchemaParams{
ID: uuid.New(),
DefaultSourceScheme: database.ParameterSourceSchemeNone,
DefaultDestinationScheme: database.ParameterDestinationSchemeNone,
ValidationTypeSystem: database.ParameterTypeSystemNone,
}).Asserts()
}))
}

View File

@ -614,6 +614,14 @@ func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params databas
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return -1, err
}
}
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
@ -892,6 +900,14 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
q.mutex.RLock()
defer q.mutex.RUnlock()
if prepared != nil {
// Call this to match the same function calls as the SQL implementation.
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return nil, err
}
}
workspaces := make([]database.Workspace, 0)
for _, workspace := range q.workspaces {
if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID {
@ -1230,6 +1246,23 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceByWorkspaceAppID(_ context.Context, workspaceAppID uuid.UUID) (database.Workspace, error) {
if err := validateDatabaseType(workspaceAppID); err != nil {
return database.Workspace{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, workspaceApp := range q.workspaceApps {
workspaceApp := workspaceApp
if workspaceApp.ID == workspaceAppID {
return q.GetWorkspaceByAgentID(context.Background(), workspaceApp.AgentID)
}
}
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1646,6 +1679,14 @@ func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.G
q.mutex.RLock()
defer q.mutex.RUnlock()
// Call this to match the same function calls as the SQL implementation.
if prepared != nil {
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL())
if err != nil {
return nil, err
}
}
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
@ -3819,6 +3860,18 @@ func (q *fakeQuerier) InsertLicense(
return l, nil
}
func (q *fakeQuerier) GetLicenseByID(_ context.Context, id int32) (database.License, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, license := range q.licenses {
if license.ID == id {
return license, nil
}
}
return database.License{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

View File

@ -66,7 +66,7 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
UserACL: seed.UserACL,
GroupACL: seed.GroupACL,
DisplayName: takeFirst(seed.DisplayName, namesgenerator.GetRandomName(1)),
AllowUserCancelWorkspaceJobs: takeFirst(seed.AllowUserCancelWorkspaceJobs, true),
AllowUserCancelWorkspaceJobs: seed.AllowUserCancelWorkspaceJobs,
})
require.NoError(t, err, "insert template")
return template
@ -369,11 +369,8 @@ func GitAuthLink(t testing.TB, db database.Store, orig database.GitAuthLink) dat
func TemplateVersion(t testing.TB, db database.Store, orig database.TemplateVersion) database.TemplateVersion {
version, err := db.InsertTemplateVersion(context.Background(), database.InsertTemplateVersionParams{
ID: takeFirst(orig.ID, uuid.New()),
TemplateID: uuid.NullUUID{
UUID: takeFirst(orig.TemplateID.UUID, uuid.New()),
Valid: takeFirst(orig.TemplateID.Valid, true),
},
ID: takeFirst(orig.ID, uuid.New()),
TemplateID: orig.TemplateID,
OrganizationID: takeFirst(orig.OrganizationID, uuid.New()),
CreatedAt: takeFirst(orig.CreatedAt, database.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()),

View File

@ -68,7 +68,7 @@ func TestGenerator(t *testing.T) {
require.Equal(t, exp, must(db.GetWorkspaceAppsByAgentID(context.Background(), exp.AgentID))[0])
})
t.Run("WorkspaceResourceMetadatum", func(t *testing.T) {
t.Run("WorkspaceResourceMetadata", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
exp := dbgen.WorkspaceResourceMetadatums(t, db, database.WorkspaceResourceMetadatum{})

View File

@ -2,6 +2,7 @@ package database
import (
"sort"
"strconv"
"github.com/coder/coder/coderd/rbac"
)
@ -63,6 +64,11 @@ func (TemplateVersion) RBACObject(template Template) rbac.Object {
return template.RBACObject()
}
// RBACObjectNoTemplate is for orphaned template versions.
func (v TemplateVersion) RBACObjectNoTemplate() rbac.Object {
return rbac.ResourceTemplate.InOrg(v.OrganizationID)
}
func (g Group) RBACObject() rbac.Object {
return rbac.ResourceGroup.WithID(g.ID).
InOrg(g.OrganizationID)
@ -94,6 +100,13 @@ func (m OrganizationMember) RBACObject() rbac.Object {
InOrg(m.OrganizationID)
}
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
// TODO: This feels incorrect as we are really returning a list of orgmembers.
// This return type should be refactored to return a list of orgmembers, not this
// special type.
return rbac.ResourceUser.WithID(m.UserID)
}
func (o Organization) RBACObject() rbac.Object {
return rbac.ResourceOrganization.
WithID(o.ID).
@ -118,11 +131,29 @@ func (u User) RBACObject() rbac.Object {
}
func (u User) UserDataRBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID).WithOwner(u.ID.String())
return rbac.ResourceUserData.WithID(u.ID).WithOwner(u.ID.String())
}
func (License) RBACObject() rbac.Object {
return rbac.ResourceLicense
func (u GetUsersRow) RBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID)
}
func (u GitSSHKey) RBACObject() rbac.Object {
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
}
func (u GitAuthLink) RBACObject() rbac.Object {
// I assume UserData is ok?
return rbac.ResourceUserData.WithID(u.UserID).WithOwner(u.UserID.String())
}
func (u UserLink) RBACObject() rbac.Object {
// I assume UserData is ok?
return rbac.ResourceUserData.WithOwner(u.UserID.String()).WithID(u.UserID)
}
func (l License) RBACObject() rbac.Object {
return rbac.ResourceLicense.WithIDString(strconv.FormatInt(int64(l.ID), 10))
}
func ConvertUserRows(rows []GetUsersRow) []User {

View File

@ -56,6 +56,7 @@ type sqlcQuerier interface {
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
@ -121,6 +122,7 @@ type sqlcQuerier interface {
GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error)
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error)
GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error)
GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error)
GetWorkspaceResourceMetadataByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceResourceMetadatum, error)
GetWorkspaceResourceMetadataCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceResourceMetadatum, error)

View File

@ -1343,6 +1343,30 @@ func (q *sqlQuerier) DeleteLicense(ctx context.Context, id int32) (int32, error)
return id, err
}
const getLicenseByID = `-- name: GetLicenseByID :one
SELECT
id, uploaded_at, jwt, exp, uuid
FROM
licenses
WHERE
id = $1
LIMIT
1
`
func (q *sqlQuerier) GetLicenseByID(ctx context.Context, id int32) (License, error) {
row := q.db.QueryRowContext(ctx, getLicenseByID, id)
var i License
err := row.Scan(
&i.ID,
&i.UploadedAt,
&i.JWT,
&i.Exp,
&i.UUID,
)
return i, err
}
const getLicenses = `-- name: GetLicenses :many
SELECT id, uploaded_at, jwt, exp, uuid
FROM licenses
@ -6513,6 +6537,62 @@ func (q *sqlQuerier) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWo
return i, err
}
const getWorkspaceByWorkspaceAppID = `-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = (
SELECT
agent_id
FROM
workspace_apps
WHERE
workspace_apps.id = $1
)
)
)
)
`
func (q *sqlQuerier) GetWorkspaceByWorkspaceAppID(ctx context.Context, workspaceAppID uuid.UUID) (Workspace, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceByWorkspaceAppID, workspaceAppID)
var i Workspace
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
)
return i, err
}
const getWorkspaces = `-- name: GetWorkspaces :many
SELECT
workspaces.id, workspaces.created_at, workspaces.updated_at, workspaces.owner_id, workspaces.organization_id, workspaces.template_id, workspaces.deleted, workspaces.name, workspaces.autostart_schedule, workspaces.ttl, workspaces.last_used_at, COUNT(*) OVER () as count

View File

@ -14,6 +14,16 @@ SELECT *
FROM licenses
ORDER BY (id);
-- name: GetLicenseByID :one
SELECT
*
FROM
licenses
WHERE
id = $1
LIMIT
1;
-- name: GetUnexpiredLicenses :many
SELECT *
FROM licenses

View File

@ -8,6 +8,42 @@ WHERE
LIMIT
1;
-- name: GetWorkspaceByWorkspaceAppID :one
SELECT
*
FROM
workspaces
WHERE
workspaces.id = (
SELECT
workspace_id
FROM
workspace_builds
WHERE
workspace_builds.job_id = (
SELECT
job_id
FROM
workspace_resources
WHERE
workspace_resources.id = (
SELECT
resource_id
FROM
workspace_agents
WHERE
workspace_agents.id = (
SELECT
agent_id
FROM
workspace_apps
WHERE
workspace_apps.id = @workspace_app_id
)
)
)
);
-- name: GetWorkspaceByAgentID :one
SELECT
*