chore: Rewrite rbac rego -> SQL clause (#5138)

* chore: Rewrite rbac rego -> SQL clause

Previous code was challenging to read with edge cases
- bug: OrgAdmin could not make new groups
- Also refactor some function names
This commit is contained in:
Steven Masley
2022-11-28 12:12:34 -06:00
committed by GitHub
parent d5ab4fdeb8
commit ab9298f382
39 changed files with 2080 additions and 828 deletions

View File

@ -20,6 +20,13 @@ import (
"github.com/coder/coder/coderd/util/slice"
)
// FakeDatabase is helpful for knowing if the underlying db is an in memory fake
// database. This is only in the databasefake package, so will only be used
// by unit tests.
type FakeDatabase interface {
IsFakeDB()
}
var errDuplicateKey = &pq.Error{
Code: "23505",
Message: "duplicate key value violates unique constraint",
@ -117,6 +124,7 @@ type data struct {
lastLicenseID int32
}
func (fakeQuerier) IsFakeDB() {}
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
return 0, nil
}
@ -488,11 +496,20 @@ func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.Get
return count, err
}
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
func (q *fakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
users := append([]database.User{}, q.users...)
users := make([]database.User, 0, len(q.users))
for _, user := range q.users {
// If the filter exists, ensure the object is authorized.
if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil {
continue
}
users = append(users, user)
}
if params.Deleted {
tmp := make([]database.User, 0, len(users))
@ -539,13 +556,6 @@ func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.
users = usersFilteredByRole
}
for _, user := range q.workspaces {
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
continue
}
}
return int64(len(users)), nil
}
@ -750,7 +760,7 @@ func (q *fakeQuerier) GetWorkspaces(ctx context.Context, arg database.GetWorkspa
}
//nolint:gocyclo
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]database.GetWorkspacesRow, error) {
func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -923,7 +933,7 @@ func (q *fakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.
}
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(workspace.RBACObject()) {
if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil {
continue
}
workspaces = append(workspaces, workspace)
@ -1505,12 +1515,20 @@ func (q *fakeQuerier) UpdateTemplateMetaByID(_ context.Context, arg database.Upd
return database.Template{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetTemplatesWithFilter(_ context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
func (q *fakeQuerier) GetTemplatesWithFilter(ctx context.Context, arg database.GetTemplatesWithFilterParams) ([]database.Template, error) {
return q.GetAuthorizedTemplates(ctx, arg, nil)
}
func (q *fakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var templates []database.Template
for _, template := range q.templates {
if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil {
continue
}
if template.Deleted != arg.Deleted {
continue
}

View File

@ -74,6 +74,7 @@ func TestExactMethods(t *testing.T) {
extraFakeMethods := map[string]string{
// Example
// "SortFakeLists": "Helper function used",
"IsFakeDB": "Helper function used for unit testing",
}
fake := reflect.TypeOf(databasefake.New())

View File

@ -5,12 +5,16 @@ import (
"fmt"
"strings"
"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/rbac/regosql"
)
"github.com/google/uuid"
"golang.org/x/xerrors"
const (
authorizedQueryPlaceholder = "-- @authorize_filter"
)
// customQuerier encompasses all non-generated queries.
@ -23,10 +27,70 @@ type customQuerier interface {
}
type templateQuerier interface {
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
}
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
VariableConverter: regosql.TemplateConverter(),
})
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.OrganizationID,
arg.ExactName,
pq.Array(arg.IDs),
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Template
for rows.Next() {
var i Template
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OrganizationID,
&i.Deleted,
&i.Name,
&i.Provisioner,
&i.ActiveVersionID,
&i.Description,
&i.DefaultTTL,
&i.CreatedBy,
&i.Icon,
&i.UserACL,
&i.GroupACL,
&i.DisplayName,
&i.AllowUserCancelWorkspaceJobs,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
type TemplateUser struct {
User
Actions Actions `db:"actions"`
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
}
type workspaceQuerier interface {
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
}
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
if err != nil {
return nil, xerrors.Errorf("compile authorized filter: %w", err)
}
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
// authorizedFilter between the end of the where clause and those statements.
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return nil, xerrors.Errorf("insert authorized filter: %w", err)
}
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.Status,
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
}
type userQuerier interface {
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
}
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
if err != nil {
return -1, xerrors.Errorf("compile authorized filter: %w", err)
}
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
if err != nil {
return -1, xerrors.Errorf("insert authorized filter: %w", err)
}
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
row := q.db.QueryRowContext(ctx, query,
arg.Deleted,
arg.Search,
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
pq.Array(arg.RbacRole),
)
var count int64
err := row.Scan(&count)
err = row.Scan(&count)
return count, err
}
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
if !strings.Contains(query, authorizedQueryPlaceholder) {
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
}
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
return filtered, nil
}

View File

@ -0,0 +1,15 @@
package database
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestIsAuthorizedQuery(t *testing.T) {
t.Parallel()
query := `SELECT true;`
_, err := insertAuthorizedFilter(query, "")
require.ErrorContains(t, err, "does not contain authorized replace string", "ensure replace string")
}

View File

@ -3197,6 +3197,8 @@ WHERE
id = ANY($4)
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
-- @authorize_filter
ORDER BY (name, id) ASC
`

View File

@ -34,6 +34,8 @@ WHERE
id = ANY(@ids)
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedTemplates
-- @authorize_filter
ORDER BY (name, id) ASC
;