mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
chore: Rewrite rbac rego -> SQL clause (#5138)
* chore: Rewrite rbac rego -> SQL clause Previous code was challenging to read with edge cases - bug: OrgAdmin could not make new groups - Also refactor some function names
This commit is contained in:
@ -20,6 +20,13 @@ import (
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
)
|
||||
|
||||
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
|
||||
// database. This is only in the databasefake package, so will only be used
|
||||
// by unit tests.
|
||||
type FakeDatabase interface {
|
||||
IsFakeDB()
|
||||
}
|
||||
|
||||
var errDuplicateKey = &pq.Error{
|
||||
Code: "23505",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
@ -117,6 +124,7 @@ type data struct {
|
||||
lastLicenseID int32
|
||||
}
|
||||
|
||||
func (fakeQuerier) IsFakeDB() {}
|
||||
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
|
||||
return 0, nil
|
||||
}
|
||||
@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
users := append([]database.User{}, q.users...)
|
||||
users := make([]database.User, 0, len(q.users))
|
||||
|
||||
for _, user := range q.users {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
if params.Deleted {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
for _, user := range q.workspaces {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
}
|
||||
|
||||
@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
|
||||
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
|
||||
}
|
||||
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
|
||||
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, workspace)
|
||||
@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
|
||||
return database.Template{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
||||
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
|
||||
return q.GetAuthorizedTemplates(ctx, arg, nil)
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var templates []database.Template
|
||||
for _, template := range q.templates {
|
||||
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if template.Deleted != arg.Deleted {
|
||||
continue
|
||||
}
|
||||
|
@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
|
||||
extraFakeMethods := map[string]string{
|
||||
// Example
|
||||
// "SortFakeLists": "Helper function used",
|
||||
"IsFakeDB": "Helper function used for unit testing",
|
||||
}
|
||||
|
||||
fake := reflect.TypeOf(databasefake.New())
|
||||
|
@ -5,12 +5,16 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
)
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
const (
|
||||
authorizedQueryPlaceholder = "-- @authorize_filter"
|
||||
)
|
||||
|
||||
// customQuerier encompasses all non-generated queries.
|
||||
@ -23,10 +27,70 @@ type customQuerier interface {
|
||||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
|
||||
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
|
||||
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
|
||||
VariableConverter: regosql.TemplateConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.OrganizationID,
|
||||
arg.ExactName,
|
||||
pq.Array(arg.IDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Deleted,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
&i.Description,
|
||||
&i.DefaultTTL,
|
||||
&i.CreatedBy,
|
||||
&i.Icon,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
&i.DisplayName,
|
||||
&i.AllowUserCancelWorkspaceJobs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type TemplateUser struct {
|
||||
User
|
||||
Actions Actions `db:"actions"`
|
||||
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
|
||||
}
|
||||
|
||||
type workspaceQuerier interface {
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
|
||||
}
|
||||
|
||||
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
||||
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
||||
// clause.
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
||||
// authorizedFilter between the end of the where clause and those statements.
|
||||
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Status,
|
||||
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
}
|
||||
|
||||
type userQuerier interface {
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
|
||||
row := q.db.QueryRowContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Search,
|
||||
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
|
||||
pq.Array(arg.RbacRole),
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
err = row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
}
|
||||
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
|
||||
return filtered, nil
|
||||
}
|
||||
|
15
coderd/database/modelqueries_internal_test.go
Normal file
15
coderd/database/modelqueries_internal_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAuthorizedQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
query := `SELECT true;`
|
||||
_, err := insertAuthorizedFilter(query, "")
|
||||
require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string")
|
||||
}
|
@ -3197,6 +3197,8 @@ WHERE
|
||||
id = ANY($4)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||
-- @authorize_filter
|
||||
ORDER BY (name, id) ASC
|
||||
`
|
||||
|
||||
|
@ -34,6 +34,8 @@ WHERE
|
||||
id = ANY(@ids)
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
|
||||
-- @authorize_filter
|
||||
ORDER BY (name, id) ASC
|
||||
;
|
||||
|
||||
|
Reference in New Issue
Block a user