mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore: Rewrite rbac rego -> SQL clause (#5138)
* chore: Rewrite rbac rego -> SQL clause Previous code was challenging to read with edge cases - bug: OrgAdmin could not make new groups - Also refactor some function names
This commit is contained in:
@ -5,12 +5,16 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
)
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
const (
|
||||
authorizedQueryPlaceholder = "-- @authorize_filter"
|
||||
)
|
||||
|
||||
// customQuerier encompasses all non-generated queries.
|
||||
@ -23,10 +27,70 @@ type customQuerier interface {
|
||||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error)
|
||||
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
|
||||
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedTemplates(ctx context.Context, arg GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]Template, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(regosql.ConvertConfig{
|
||||
VariableConverter: regosql.TemplateConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getTemplatesWithFilter, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedTemplates :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.OrganizationID,
|
||||
arg.ExactName,
|
||||
pq.Array(arg.IDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Template
|
||||
for rows.Next() {
|
||||
var i Template
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.OrganizationID,
|
||||
&i.Deleted,
|
||||
&i.Name,
|
||||
&i.Provisioner,
|
||||
&i.ActiveVersionID,
|
||||
&i.Description,
|
||||
&i.DefaultTTL,
|
||||
&i.CreatedBy,
|
||||
&i.Icon,
|
||||
&i.UserACL,
|
||||
&i.GroupACL,
|
||||
&i.DisplayName,
|
||||
&i.AllowUserCancelWorkspaceJobs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type TemplateUser struct {
|
||||
User
|
||||
Actions Actions `db:"actions"`
|
||||
@ -112,18 +176,27 @@ func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([
|
||||
}
|
||||
|
||||
type workspaceQuerier interface {
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error)
|
||||
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error)
|
||||
}
|
||||
|
||||
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
|
||||
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
|
||||
// clause.
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]GetWorkspacesRow, error) {
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]GetWorkspacesRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
||||
// authorizedFilter between the end of the where clause and those statements.
|
||||
filter := strings.Replace(getWorkspaces, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
filtered, err := insertAuthorizedFilter(getWorkspaces, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filter)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Status,
|
||||
@ -172,12 +245,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
}
|
||||
|
||||
type userQuerier interface {
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(rbac.ConfigWithoutACL())
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
|
||||
filtered, err := insertAuthorizedFilter(getFilteredUserCount, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return -1, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filtered)
|
||||
row := q.db.QueryRowContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Search,
|
||||
@ -185,6 +267,14 @@ func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFiltered
|
||||
pq.Array(arg.RbacRole),
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
err = row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
}
|
||||
filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1)
|
||||
return filtered, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user