feat: add template RBAC/groups (#4235)

This commit is contained in:
Jon Ayers
2022-10-10 15:37:06 -05:00
committed by GitHub
parent 2687e3db49
commit 3120c94c22
122 changed files with 8088 additions and 1062 deletions

View File

@ -265,6 +265,7 @@ func auditSearchQuery(query string) (database.GetAuditLogsOffsetParams, []coders
Username: parser.String(searchParams, "", "username"),
Email: parser.String(searchParams, "", "email"),
}
return filter, parser.Errors
}
@ -296,6 +297,7 @@ func actionFromString(actionString string) string {
return actionString
case codersdk.AuditActionDelete:
return actionString
default:
}
return ""
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"net/http"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
@ -18,7 +20,7 @@ import (
// This is faster than calling Authorize() on each object.
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
roles := httpmw.UserAuthorization(r)
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects)
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objects)
if err != nil {
// Log the error as Filter should not be erroring.
h.Logger.Error(r.Context(), "filter failed",
@ -63,7 +65,7 @@ func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objec
// }
func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
roles := httpmw.UserAuthorization(r)
err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, object.RBACObject())
err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, object.RBACObject())
if err != nil {
// Log the errors for debugging
internalError := new(rbac.UnauthorizedError)
@ -95,7 +97,7 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
// Note the authorization is only for the given action and object type.
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
roles := httpmw.UserAuthorization(r)
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objectType)
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare filter: %w", err)
}
@ -127,6 +129,28 @@ func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
)
response := make(codersdk.AuthorizationResponse)
// Prevent using too many resources by ID. This prevents database abuse
// from this endpoint. This also prevents misuse of this endpoint, as
// resource_id should be used for single objects, not for a list of them.
var (
idFetch int
maxFetch = 10
)
for _, v := range params.Checks {
if v.Object.ResourceID != "" {
idFetch++
}
}
if idFetch > maxFetch {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf(
"Endpoint only supports using \"resource_id\" field %d times, found %d usages. Remove %d objects with this field set.",
maxFetch, idFetch, idFetch-maxFetch,
),
})
return
}
for k, v := range params.Checks {
if v.Object.ResourceType == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -135,15 +159,60 @@ func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
return
}
if v.Object.OwnerID == "me" {
v.Object.OwnerID = auth.ID.String()
obj := rbac.Object{
Owner: v.Object.OwnerID,
OrgID: v.Object.OrganizationID,
Type: v.Object.ResourceType,
}
err := api.Authorizer.ByRoleName(r.Context(), auth.ID.String(), auth.Roles, auth.Scope.ToRBAC(), rbac.Action(v.Action),
rbac.Object{
Owner: v.Object.OwnerID,
OrgID: v.Object.OrganizationID,
Type: v.Object.ResourceType,
})
if obj.Owner == "me" {
obj.Owner = auth.ID.String()
}
// If a resource ID is specified, fetch that specific resource.
if v.Object.ResourceID != "" {
id, err := uuid.Parse(v.Object.ResourceID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Object %q id is not a valid uuid.", v.Object.ResourceID),
Validations: []codersdk.ValidationError{{Field: "resource_id", Detail: err.Error()}},
})
return
}
var dbObj rbac.Objecter
var dbErr error
// Only support referencing some resources by ID.
switch v.Object.ResourceType {
case rbac.ResourceWorkspaceExecution.Type:
wrkSpace, err := api.Database.GetWorkspaceByID(ctx, id)
if err == nil {
dbObj = wrkSpace.ExecutionRBAC()
}
dbErr = err
case rbac.ResourceWorkspace.Type:
dbObj, dbErr = api.Database.GetWorkspaceByID(ctx, id)
case rbac.ResourceTemplate.Type:
dbObj, dbErr = api.Database.GetTemplateByID(ctx, id)
case rbac.ResourceUser.Type:
dbObj, dbErr = api.Database.GetUserByID(ctx, id)
case rbac.ResourceGroup.Type:
dbObj, dbErr = api.Database.GetGroupByID(ctx, id)
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Object type %q does not support \"resource_id\" field.", v.Object.ResourceType),
Validations: []codersdk.ValidationError{{Field: "resource_type", Detail: err.Error()}},
})
return
}
if dbErr != nil {
// 404 or unauthorized is false
response[k] = false
continue
}
obj = dbObj.RBACObject()
}
err := api.Authorizer.ByRoleName(r.Context(), auth.ID.String(), auth.Roles, auth.Scope.ToRBAC(), auth.Groups, rbac.Action(v.Action), obj)
response[k] = err == nil
}

View File

@ -19,7 +19,9 @@ func TestCheckPermissions(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
t.Cleanup(cancel)
adminClient := coderdtest.New(t, nil)
adminClient := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
// Create adminClient, member, and org adminClient
adminUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
@ -29,12 +31,17 @@ func TestCheckPermissions(t *testing.T) {
orgAdminUser, err := orgAdminClient.User(ctx, codersdk.Me)
require.NoError(t, err)
version := coderdtest.CreateTemplateVersion(t, adminClient, adminUser.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJob(t, adminClient, version.ID)
template := coderdtest.CreateTemplate(t, adminClient, adminUser.OrganizationID, version.ID)
// With admin, member, and org admin
const (
readAllUsers = "read-all-users"
readOrgWorkspaces = "read-org-workspaces"
readMyself = "read-myself"
readOwnWorkspaces = "read-own-workspaces"
readAllUsers = "read-all-users"
readOrgWorkspaces = "read-org-workspaces"
readMyself = "read-myself"
readOwnWorkspaces = "read-own-workspaces"
updateSpecificTemplate = "update-specific-template"
)
params := map[string]codersdk.AuthorizationCheck{
readAllUsers: {
@ -64,6 +71,13 @@ func TestCheckPermissions(t *testing.T) {
},
Action: "read",
},
updateSpecificTemplate: {
Object: codersdk.AuthorizationObject{
ResourceType: rbac.ResourceTemplate.Type,
ResourceID: template.ID.String(),
},
Action: "update",
},
}
testCases := []struct {
@ -77,10 +91,11 @@ func TestCheckPermissions(t *testing.T) {
Client: adminClient,
UserID: adminUser.UserID,
Check: map[string]bool{
readAllUsers: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
readAllUsers: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
{
@ -88,10 +103,11 @@ func TestCheckPermissions(t *testing.T) {
Client: orgAdminClient,
UserID: orgAdminUser.ID,
Check: map[string]bool{
readAllUsers: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
readAllUsers: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
{
@ -99,10 +115,11 @@ func TestCheckPermissions(t *testing.T) {
Client: memberClient,
UserID: memberUser.ID,
Check: map[string]bool{
readAllUsers: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: false,
readAllUsers: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: false,
updateSpecificTemplate: false,
},
},
}

View File

@ -283,6 +283,7 @@ func New(options *Options) *API {
r.Get("/{hash}", api.fileByHash)
r.Post("/", api.postFile)
})
r.Route("/provisionerdaemons", func(r chi.Router) {
r.Use(
apiKeyMiddleware,

View File

@ -499,6 +499,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
type authCall struct {
SubjectID string
Roles []string
Groups []string
Scope rbac.Scope
Action rbac.Action
Object rbac.Object
@ -513,14 +514,15 @@ var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
// ByRoleNameSQL does not record the call. This matches the postgres behavior
// of not calling Authorize()
func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ rbac.Action, _ rbac.Object) error {
func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ []string, _ rbac.Action, _ rbac.Object) error {
return r.AlwaysReturn
}
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, groups []string, action rbac.Action, object rbac.Object) error {
r.Called = &authCall{
SubjectID: subjectID,
Roles: roleNames,
Groups: groups,
Scope: scope,
Action: action,
Object: object,
@ -528,7 +530,7 @@ func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
return r.AlwaysReturn
}
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, groups []string, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: r,
SubjectID: subjectID,
@ -536,6 +538,7 @@ func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID str
Scope: scope,
Action: action,
HardCodedSQLString: "true",
Groups: groups,
}, nil
}
@ -549,12 +552,13 @@ type fakePreparedAuthorizer struct {
Roles []string
Scope rbac.Scope
Action rbac.Action
Groups []string
HardCodedSQLString string
HardCodedRegoString string
}
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object)
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
}
// Compile returns a compiled version of the authorizer that will work for
@ -564,7 +568,7 @@ func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
}
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Action, object) == nil
return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object) == nil
}
func (f fakePreparedAuthorizer) RegoString() string {

View File

@ -9,7 +9,6 @@ import (
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/base64"
"encoding/json"
"encoding/pem"
@ -21,7 +20,6 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strconv"
"strings"
"testing"
@ -49,8 +47,7 @@ import (
"github.com/coder/coder/coderd/autobuild/executor"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
@ -139,26 +136,7 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
})
}
// This can be hotswapped for a live database instance.
db := databasefake.New()
pubsub := database.NewPubsubInMemory()
if os.Getenv("DB") != "" {
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
t.Cleanup(closePg)
sqlDB, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
})
db = database.New(sqlDB)
pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = pubsub.Close()
})
}
db, pubsub := dbtestutil.NewDB(t)
ctx, cancelFunc := context.WithCancel(context.Background())
lifecycleExecutor := executor.New(
@ -399,6 +377,7 @@ func createAnotherUserRetry(t *testing.T, client *codersdk.Client, organizationI
// with the responses provided. It uses the "echo" provisioner for compatibility
// with testing.
func CreateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses) codersdk.TemplateVersion {
t.Helper()
data, err := echo.Tar(res)
require.NoError(t, err)
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)

View File

@ -1,62 +0,0 @@
package database
import (
"context"
"fmt"
"github.com/lib/pq"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
)
type customQuerier interface {
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
}
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) {
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.DefaultConfig()))
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.OwnerID,
arg.OwnerUsername,
arg.TemplateName,
pq.Array(arg.TemplateIds),
arg.Name,
)
if err != nil {
return nil, xerrors.Errorf("get authorized workspaces: %w", err)
}
defer rows.Close()
var items []Workspace
for rows.Next() {
var i Workspace
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -12,23 +12,30 @@ import (
"github.com/lib/pq"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/util/slice"
)
var errDuplicateKey = &pq.Error{
Code: "23505",
Message: "duplicate key value violates unique constraint",
}
// New returns an in-memory fake of the database.
func New() database.Store {
return &fakeQuerier{
mutex: &sync.RWMutex{},
data: &data{
apiKeys: make([]database.APIKey, 0),
agentStats: make([]database.AgentStat, 0),
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
apiKeys: make([]database.APIKey, 0),
agentStats: make([]database.AgentStat, 0),
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
groups: make([]database.Group, 0),
groupMembers: make([]database.GroupMember, 0),
auditLogs: make([]database.AuditLog, 0),
files: make([]database.File, 0),
gitSSHKey: make([]database.GitSSHKey, 0),
@ -84,6 +91,8 @@ type data struct {
auditLogs []database.AuditLog
files []database.File
gitSSHKey []database.GitSSHKey
groups []database.Group
groupMembers []database.GroupMember
parameterSchemas []database.ParameterSchema
parameterValues []database.ParameterValue
provisionerDaemons []database.ProvisionerDaemon
@ -518,6 +527,13 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
}
}
var groups []string
for _, member := range q.groupMembers {
if member.UserID == userID {
groups = append(groups, member.GroupID.String())
}
}
if user == nil {
return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows
}
@ -527,6 +543,7 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
Username: user.Username,
Status: user.Status,
Roles: roles,
Groups: groups,
}, nil
}
@ -1269,6 +1286,116 @@ func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, erro
return templates, nil
}
func (q *fakeQuerier) UpdateTemplateUserACLByID(_ context.Context, id uuid.UUID, acl database.TemplateACL) error {
q.mutex.RLock()
defer q.mutex.RUnlock()
for i, t := range q.templates {
if t.ID == id {
t = t.SetUserACL(acl)
q.templates[i] = t
return nil
}
}
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateTemplateGroupACLByID(_ context.Context, id uuid.UUID, acl database.TemplateACL) error {
q.mutex.RLock()
defer q.mutex.RUnlock()
for i, t := range q.templates {
if t.ID == id {
t = t.SetGroupACL(acl)
q.templates[i] = t
return nil
}
}
return sql.ErrNoRows
}
func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
acl := template.UserACL()
users := make([]database.TemplateUser, 0, len(acl))
for k, v := range acl {
user, err := q.GetUserByID(context.Background(), uuid.MustParse(k))
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get user by ID: %w", err)
}
// We don't delete users from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
if user.Deleted || user.Status == database.UserStatusSuspended {
continue
}
users = append(users, database.TemplateUser{
User: user,
Actions: v,
})
}
return users, nil
}
func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var template database.Template
for _, t := range q.templates {
if t.ID == id {
template = t
break
}
}
if template.ID == uuid.Nil {
return nil, sql.ErrNoRows
}
acl := template.GroupACL()
groups := make([]database.TemplateGroup, 0, len(acl))
for k, v := range acl {
group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k))
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get group by ID: %w", err)
}
// We don't delete groups from the map if they
// get deleted so just skip.
if xerrors.Is(err, sql.ErrNoRows) {
continue
}
groups = append(groups, database.TemplateGroup{
Group: group,
Actions: v,
})
}
return groups, nil
}
func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1749,6 +1876,10 @@ func (q *fakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTempl
MinAutostartInterval: arg.MinAutostartInterval,
CreatedBy: arg.CreatedBy,
}
template = template.SetUserACL(database.TemplateACL{})
template = template.SetGroupACL(database.TemplateACL{
arg.OrganizationID.String(): []rbac.Action{rbac.ActionRead},
})
q.templates = append(q.templates, template)
return template, nil
}
@ -2299,7 +2430,7 @@ func (q *fakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWork
continue
}
if other.Name == arg.Name {
return database.Workspace{}, &pq.Error{Code: "23505", Message: "duplicate key value violates unique constraint"}
return database.Workspace{}, errDuplicateKey
}
}
@ -2437,6 +2568,52 @@ func (q *fakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitS
return sql.ErrNoRows
}
func (q *fakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for _, member := range q.groupMembers {
if member.GroupID == arg.GroupID &&
member.UserID == arg.UserID {
return errDuplicateKey
}
}
//nolint:gosimple
q.groupMembers = append(q.groupMembers, database.GroupMember{
GroupID: arg.GroupID,
UserID: arg.UserID,
})
return nil
}
func (q *fakeQuerier) DeleteGroupMember(_ context.Context, userID uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for i, member := range q.groupMembers {
if member.UserID == userID {
q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...)
}
}
return nil
}
func (q *fakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
for i, group := range q.groups {
if group.ID == arg.ID {
group.Name = arg.Name
q.groups[i] = group
return group, nil
}
}
return database.Group{}, sql.ErrNoRows
}
func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -2714,3 +2891,137 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, group := range q.groups {
if group.ID == id {
return group, nil
}
}
return database.Group{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, group := range q.groups {
if group.OrganizationID == arg.OrganizationID &&
group.Name == arg.Name {
return group, nil
}
}
return database.Group{}, sql.ErrNoRows
}
func (q *fakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) {
return q.InsertGroup(ctx, database.InsertGroupParams{
ID: orgID,
Name: database.AllUsersGroup,
OrganizationID: orgID,
})
}
func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, group := range q.groups {
if group.OrganizationID.String() == arg.OrganizationID.String() &&
group.Name == arg.Name {
return database.Group{}, errDuplicateKey
}
}
//nolint:gosimple
group := database.Group{
ID: arg.ID,
Name: arg.Name,
OrganizationID: arg.OrganizationID,
}
q.groups = append(q.groups, group)
return group, nil
}
func (*fakeQuerier) GetUserGroups(_ context.Context, _ uuid.UUID) ([]database.Group, error) {
panic("not implemented")
}
func (q *fakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]database.User, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var members []database.GroupMember
for _, member := range q.groupMembers {
if member.GroupID == groupID {
members = append(members, member)
}
}
users := make([]database.User, 0, len(members))
for _, member := range members {
for _, user := range q.users {
if user.ID == member.UserID && user.Status == database.UserStatusActive && !user.Deleted {
users = append(users, user)
break
}
}
}
return users, nil
}
func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationID uuid.UUID) ([]database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var groups []database.Group
for _, group := range q.groups {
// Omit the allUsers group.
if group.OrganizationID == organizationID && group.ID != organizationID {
groups = append(groups, group)
}
}
return groups, nil
}
func (q *fakeQuerier) GetAllOrganizationMembers(_ context.Context, organizationID uuid.UUID) ([]database.User, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var users []database.User
for _, member := range q.organizationMembers {
if member.OrganizationID == organizationID {
for _, user := range q.users {
if user.ID == member.UserID {
users = append(users, user)
}
}
}
}
return users, nil
}
func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for i, group := range q.groups {
if group.ID == id {
q.groups = append(q.groups[:i], q.groups[i+1:]...)
return nil
}
}
return sql.ErrNoRows
}

View File

@ -13,6 +13,7 @@ import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
"golang.org/x/xerrors"
)
@ -32,24 +33,34 @@ type DBTX interface {
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}
// New creates a new database store using a SQL database connection.
func New(sdb *sql.DB) Store {
dbx := sqlx.NewDb(sdb, "postgres")
return &sqlQuerier{
db: sdb,
sdb: sdb,
db: dbx,
sdb: dbx,
}
}
// queries encompasses both are sqlc generated
// queries and our custom queries.
type querier interface {
sqlcQuerier
customQuerier
}
type sqlQuerier struct {
sdb *sql.DB
sdb *sqlx.DB
db DBTX
}
// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(function func(Store) error) error {
if _, ok := q.db.(*sql.Tx); ok {
if _, ok := q.db.(*sqlx.Tx); ok {
// If the current inner "db" is already a transaction, we just reuse it.
// We do not need to handle commit/rollback as the outer tx will handle
// that.
@ -60,7 +71,7 @@ func (q *sqlQuerier) InTx(function func(Store) error) error {
return nil
}
transaction, err := q.sdb.Begin()
transaction, err := q.sdb.BeginTxx(context.Background(), nil)
if err != nil {
return xerrors.Errorf("begin transaction: %w", err)
}

View File

@ -0,0 +1,40 @@
package dbtestutil
import (
"context"
"database/sql"
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
)
func NewDB(t *testing.T) (database.Store, database.Pubsub) {
t.Helper()
db := databasefake.New()
pubsub := database.NewPubsubInMemory()
if os.Getenv("DB") != "" {
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
t.Cleanup(closePg)
sqlDB, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
})
db = database.New(sqlDB)
pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = pubsub.Close()
})
}
return db, pubsub
}

View File

@ -0,0 +1,26 @@
package database
import (
"database/sql/driver"
"encoding/json"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
)
type Actions []rbac.Action
func (a *Actions) Scan(src interface{}) error {
switch v := src.(type) {
case string:
return json.Unmarshal([]byte(v), &a)
case []byte:
return json.Unmarshal(v, &a)
}
return xerrors.Errorf("unexpected type %T", src)
}
func (a *Actions) Value() (driver.Value, error) {
return json.Marshal(a)
}

View File

@ -162,6 +162,17 @@ CREATE TABLE gitsshkeys (
public_key text NOT NULL
);
CREATE TABLE group_members (
user_id uuid NOT NULL,
group_id uuid NOT NULL
);
CREATE TABLE groups (
id uuid NOT NULL,
name text NOT NULL,
organization_id uuid NOT NULL
);
CREATE TABLE licenses (
id integer NOT NULL,
uploaded_at timestamp with time zone NOT NULL,
@ -295,7 +306,9 @@ CREATE TABLE templates (
max_ttl bigint DEFAULT '604800000000000'::bigint NOT NULL,
min_autostart_interval bigint DEFAULT '3600000000000'::bigint NOT NULL,
created_by uuid NOT NULL,
icon character varying(256) DEFAULT ''::character varying NOT NULL
icon character varying(256) DEFAULT ''::character varying NOT NULL,
user_acl jsonb DEFAULT '{}'::jsonb NOT NULL,
group_acl jsonb DEFAULT '{}'::jsonb NOT NULL
);
CREATE TABLE user_links (
@ -424,6 +437,15 @@ ALTER TABLE ONLY files
ALTER TABLE ONLY gitsshkeys
ADD CONSTRAINT gitsshkeys_pkey PRIMARY KEY (user_id);
ALTER TABLE ONLY group_members
ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id);
ALTER TABLE ONLY groups
ADD CONSTRAINT groups_name_organization_id_key UNIQUE (name, organization_id);
ALTER TABLE ONLY groups
ADD CONSTRAINT groups_pkey PRIMARY KEY (id);
ALTER TABLE ONLY licenses
ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt);
@ -545,6 +567,15 @@ ALTER TABLE ONLY api_keys
ALTER TABLE ONLY gitsshkeys
ADD CONSTRAINT gitsshkeys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id);
ALTER TABLE ONLY group_members
ADD CONSTRAINT group_members_group_id_fkey FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE;
ALTER TABLE ONLY group_members
ADD CONSTRAINT group_members_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY groups
ADD CONSTRAINT groups_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY organization_members
ADD CONSTRAINT organization_members_organization_id_uuid_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;

View File

@ -42,7 +42,7 @@ SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}")
rm -f queries/*.go
# Fix struct/interface names.
gofmt -w -r 'Querier -> querier' -- *.go
gofmt -w -r 'Querier -> sqlcQuerier' -- *.go
gofmt -w -r 'Queries -> sqlQuerier' -- *.go
# Ensure correct imports exist. Modules must all be downloaded so we get correct

View File

@ -0,0 +1,8 @@
BEGIN;
DROP TABLE group_members;
DROP TABLE groups;
ALTER TABLE templates DROP COLUMN group_acl;
ALTER TABLE templates DROP COLUMN user_acl;
COMMIT;

View File

@ -0,0 +1,48 @@
BEGIN;
ALTER TABLE templates ADD COLUMN user_acl jsonb NOT NULL default '{}';
ALTER TABLE templates ADD COLUMN group_acl jsonb NOT NULL default '{}';
CREATE TABLE groups (
id uuid NOT NULL,
name text NOT NULL,
organization_id uuid NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
PRIMARY KEY(id),
UNIQUE(name, organization_id)
);
CREATE TABLE group_members (
user_id uuid NOT NULL,
group_id uuid NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE,
UNIQUE(user_id, group_id)
);
-- Insert a group for every organization (which should just be 1).
INSERT INTO groups (
id,
name,
organization_id
) SELECT
id, 'Everyone' as name, id
FROM
organizations;
-- Insert allUsers groups into every existing template to avoid breaking
-- existing deployments.
UPDATE
templates
SET
group_acl = (
SELECT
json_build_object(
organizations.id, array_to_json('{"read"}'::text[])
)
FROM
organizations
WHERE
templates.organization_id = organizations.id
);
COMMIT;

View File

@ -1,9 +1,65 @@
package database
import (
"encoding/json"
"fmt"
"github.com/coder/coder/coderd/rbac"
)
const AllUsersGroup = "Everyone"
// TemplateACL is a map of user_ids to permissions.
type TemplateACL map[string][]rbac.Action
func (t Template) UserACL() TemplateACL {
var acl TemplateACL
if len(t.userACL) == 0 {
return acl
}
err := json.Unmarshal(t.userACL, &acl)
if err != nil {
panic(fmt.Sprintf("failed to unmarshal template.userACL: %v", err.Error()))
}
return acl
}
func (t Template) GroupACL() TemplateACL {
var acl TemplateACL
if len(t.groupACL) == 0 {
return acl
}
err := json.Unmarshal(t.groupACL, &acl)
if err != nil {
panic(fmt.Sprintf("failed to unmarshal template.userACL: %v", err.Error()))
}
return acl
}
func (t Template) SetGroupACL(acl TemplateACL) Template {
raw, err := json.Marshal(acl)
if err != nil {
panic(fmt.Sprintf("marshal user acl: %v", err))
}
t.groupACL = raw
return t
}
func (t Template) SetUserACL(acl TemplateACL) Template {
raw, err := json.Marshal(acl)
if err != nil {
panic(fmt.Sprintf("marshal user acl: %v", err))
}
t.userACL = raw
return t
}
func (s APIKeyScope) ToRBAC() rbac.Scope {
switch s {
case APIKeyScopeAll:
@ -16,12 +72,19 @@ func (s APIKeyScope) ToRBAC() rbac.Scope {
}
func (t Template) RBACObject() rbac.Object {
return rbac.ResourceTemplate.InOrg(t.OrganizationID)
obj := rbac.ResourceTemplate
return obj.InOrg(t.OrganizationID).
WithACLUserList(t.UserACL()).
WithGroupACL(t.GroupACL())
}
func (t TemplateVersion) RBACObject() rbac.Object {
func (TemplateVersion) RBACObject(template Template) rbac.Object {
// Just use the parent template resource for controlling versions
return rbac.ResourceTemplate.InOrg(t.OrganizationID)
return template.RBACObject()
}
func (g Group) RBACObject() rbac.Object {
return rbac.ResourceGroup.InOrg(g.OrganizationID)
}
func (w Workspace) RBACObject() rbac.Object {

View File

@ -0,0 +1,208 @@
package database
import (
"context"
"encoding/json"
"fmt"
"github.com/lib/pq"
"github.com/coder/coder/coderd/rbac"
"github.com/google/uuid"
"golang.org/x/xerrors"
)
// customQuerier encompasses all non-generated queries.
// It provides a flexible way to write queries for cases
// where sqlc proves inadequate.
type customQuerier interface {
templateQuerier
workspaceQuerier
}
type templateQuerier interface {
UpdateTemplateUserACLByID(ctx context.Context, id uuid.UUID, acl TemplateACL) error
UpdateTemplateGroupACLByID(ctx context.Context, id uuid.UUID, acl TemplateACL) error
GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error)
GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error)
}
type TemplateUser struct {
User
Actions Actions `db:"actions"`
}
func (q *sqlQuerier) UpdateTemplateUserACLByID(ctx context.Context, id uuid.UUID, acl TemplateACL) error {
raw, err := json.Marshal(acl)
if err != nil {
return xerrors.Errorf("marshal user acl: %w", err)
}
const query = `
UPDATE
templates
SET
user_acl = $2
WHERE
id = $1`
_, err = q.db.ExecContext(ctx, query, id.String(), raw)
if err != nil {
return xerrors.Errorf("update user acl: %w", err)
}
return nil
}
func (q *sqlQuerier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]TemplateUser, error) {
const query = `
SELECT
perms.value as actions, users.*
FROM
users
JOIN
(
SELECT
*
FROM
jsonb_each_text(
(
SELECT
templates.user_acl
FROM
templates
WHERE
id = $1
)
)
) AS perms
ON
users.id::text = perms.key
WHERE
users.deleted = false
AND
users.status = 'active';
`
var tus []TemplateUser
err := q.db.SelectContext(ctx, &tus, query, id.String())
if err != nil {
return nil, xerrors.Errorf("select user actions: %w", err)
}
return tus, nil
}
type TemplateGroup struct {
Group
Actions Actions `db:"actions"`
}
func (q *sqlQuerier) UpdateTemplateGroupACLByID(ctx context.Context, id uuid.UUID, acl TemplateACL) error {
raw, err := json.Marshal(acl)
if err != nil {
return xerrors.Errorf("marshal user acl: %w", err)
}
const query = `
UPDATE
templates
SET
group_acl = $2
WHERE
id = $1`
_, err = q.db.ExecContext(ctx, query, id.String(), raw)
if err != nil {
return xerrors.Errorf("update user acl: %w", err)
}
return nil
}
func (q *sqlQuerier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]TemplateGroup, error) {
const query = `
SELECT
perms.value as actions, groups.*
FROM
groups
JOIN
(
SELECT
*
FROM
jsonb_each_text(
(
SELECT
templates.group_acl
FROM
templates
WHERE
id = $1
)
)
) AS perms
ON
groups.id::text = perms.key;
`
var tgs []TemplateGroup
err := q.db.SelectContext(ctx, &tgs, query, id.String())
if err != nil {
return nil, xerrors.Errorf("select group roles: %w", err)
}
return tgs, nil
}
type workspaceQuerier interface {
GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error)
}
// GetAuthorizedWorkspaces returns all workspaces that the user is authorized to access.
// This code is copied from `GetWorkspaces` and adds the authorized filter WHERE
// clause.
func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspacesParams, authorizedFilter rbac.AuthorizeFilter) ([]Workspace, error) {
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaces :many\n%s AND %s", getWorkspaces, authorizedFilter.SQLString(rbac.NoACLConfig()))
rows, err := q.db.QueryContext(ctx, query,
arg.Deleted,
arg.OwnerID,
arg.OwnerUsername,
arg.TemplateName,
pq.Array(arg.TemplateIds),
arg.Name,
)
if err != nil {
return nil, xerrors.Errorf("get authorized workspaces: %w", err)
}
defer rows.Close()
var items []Workspace
for rows.Next() {
var i Workspace
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.OwnerID,
&i.OrganizationID,
&i.TemplateID,
&i.Deleted,
&i.Name,
&i.AutostartSchedule,
&i.Ttl,
&i.LastUsedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
)
@ -413,6 +414,17 @@ type GitSSHKey struct {
PublicKey string `db:"public_key" json:"public_key"`
}
type Group struct {
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
}
type GroupMember struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupID uuid.UUID `db:"group_id" json:"group_id"`
}
type License struct {
ID int32 `db:"id" json:"id"`
UploadedAt time.Time `db:"uploaded_at" json:"uploaded_at"`
@ -524,6 +536,8 @@ type Template struct {
MinAutostartInterval int64 `db:"min_autostart_interval" json:"min_autostart_interval"`
CreatedBy uuid.UUID `db:"created_by" json:"created_by"`
Icon string `db:"icon" json:"icon"`
userACL json.RawMessage `db:"user_acl" json:"user_acl"`
groupACL json.RawMessage `db:"group_acl" json:"group_acl"`
}
type TemplateVersion struct {
@ -546,7 +560,7 @@ type User struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Status UserStatus `db:"status" json:"status"`
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
RBACRoles pq.StringArray `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
AvatarURL sql.NullString `db:"avatar_url" json:"avatar_url"`
Deleted bool `db:"deleted" json:"deleted"`

View File

@ -11,7 +11,7 @@ import (
"github.com/google/uuid"
)
type querier interface {
type sqlcQuerier interface {
// Acquires the lock for a single job that isn't started, completed,
// canceled, and that matches an array of provisioner types.
//
@ -21,6 +21,8 @@ type querier interface {
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
DeleteGroupMember(ctx context.Context, userID uuid.UUID) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOldAgentStats(ctx context.Context) error
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
@ -28,6 +30,7 @@ type querier interface {
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveUserCount(ctx context.Context) (int64, error)
GetAllOrganizationMembers(ctx context.Context, organizationID uuid.UUID) ([]User, error)
GetAuditLogCount(ctx context.Context, arg GetAuditLogCountParams) (int64, error)
// GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided
// ID.
@ -38,6 +41,10 @@ type querier interface {
GetDeploymentID(ctx context.Context) (string, error)
GetFileByHash(ctx context.Context, hash string) (File, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetLatestAgentStat(ctx context.Context, agentID uuid.UUID) (AgentStat, error)
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
@ -73,6 +80,7 @@ type querier interface {
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserCount(ctx context.Context) (int64, error)
GetUserGroups(ctx context.Context, userID uuid.UUID) ([]Group, error)
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error)
@ -108,10 +116,16 @@ type querier interface {
GetWorkspaces(ctx context.Context, arg GetWorkspacesParams) ([]Workspace, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertAgentStat(ctx context.Context, arg InsertAgentStatParams) (AgentStat, error)
// We use the organization_id as the id
// for simplicity since all users is
// every member of the org.
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertDeploymentID(ctx context.Context, value string) error
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error)
InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error)
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
@ -134,6 +148,7 @@ type querier interface {
ParameterValues(ctx context.Context, arg ParameterValuesParams) ([]ParameterValue, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error
UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error)
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)
UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
@ -163,4 +178,4 @@ type querier interface {
UpdateWorkspaceTTL(ctx context.Context, arg UpdateWorkspaceTTLParams) error
}
var _ querier = (*sqlQuerier)(nil)
var _ sqlcQuerier = (*sqlQuerier)(nil)

View File

@ -807,6 +807,328 @@ func (q *sqlQuerier) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyPar
return err
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM
groups
WHERE
id = $1
`
func (q *sqlQuerier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteGroupByID, id)
return err
}
const deleteGroupMember = `-- name: DeleteGroupMember :exec
DELETE FROM
group_members
WHERE
user_id = $1
`
func (q *sqlQuerier) DeleteGroupMember(ctx context.Context, userID uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteGroupMember, userID)
return err
}
const getAllOrganizationMembers = `-- name: GetAllOrganizationMembers :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at
FROM
users
JOIN
organization_members
ON
users.id = organization_members.user_id
WHERE
organization_members.organization_id = $1
`
func (q *sqlQuerier) GetAllOrganizationMembers(ctx context.Context, organizationID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getAllOrganizationMembers, organizationID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupByID = `-- name: GetGroupByID :one
SELECT
id, name, organization_id
FROM
groups
WHERE
id = $1
LIMIT
1
`
func (q *sqlQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error) {
row := q.db.QueryRowContext(ctx, getGroupByID, id)
var i Group
err := row.Scan(&i.ID, &i.Name, &i.OrganizationID)
return i, err
}
const getGroupByOrgAndName = `-- name: GetGroupByOrgAndName :one
SELECT
id, name, organization_id
FROM
groups
WHERE
organization_id = $1
AND
name = $2
LIMIT
1
`
type GetGroupByOrgAndNameParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error) {
row := q.db.QueryRowContext(ctx, getGroupByOrgAndName, arg.OrganizationID, arg.Name)
var i Group
err := row.Scan(&i.ID, &i.Name, &i.OrganizationID)
return i, err
}
const getGroupMembers = `-- name: GetGroupMembers :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false'
`
func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupsByOrganizationID = `-- name: GetGroupsByOrganizationID :many
SELECT
id, name, organization_id
FROM
groups
WHERE
organization_id = $1
AND
id != $1
`
func (q *sqlQuerier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroupsByOrganizationID, organizationID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(&i.ID, &i.Name, &i.OrganizationID); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getUserGroups = `-- name: GetUserGroups :many
SELECT
groups.id, groups.name, groups.organization_id
FROM
groups
JOIN
group_members
ON
groups.id = group_members.group_id
WHERE
group_members.user_id = $1
`
func (q *sqlQuerier) GetUserGroups(ctx context.Context, userID uuid.UUID) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getUserGroups, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(&i.ID, &i.Name, &i.OrganizationID); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertAllUsersGroup = `-- name: InsertAllUsersGroup :one
INSERT INTO groups (
id,
name,
organization_id
)
VALUES
( $1, 'Everyone', $1) RETURNING id, name, organization_id
`
// We use the organization_id as the id
// for simplicity since all users is
// every member of the org.
func (q *sqlQuerier) InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error) {
row := q.db.QueryRowContext(ctx, insertAllUsersGroup, organizationID)
var i Group
err := row.Scan(&i.ID, &i.Name, &i.OrganizationID)
return i, err
}
const insertGroup = `-- name: InsertGroup :one
INSERT INTO groups (
id,
name,
organization_id
)
VALUES
( $1, $2, $3) RETURNING id, name, organization_id
`
type InsertGroupParams struct {
ID uuid.UUID `db:"id" json:"id"`
Name string `db:"name" json:"name"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
}
func (q *sqlQuerier) InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error) {
row := q.db.QueryRowContext(ctx, insertGroup, arg.ID, arg.Name, arg.OrganizationID)
var i Group
err := row.Scan(&i.ID, &i.Name, &i.OrganizationID)
return i, err
}
const insertGroupMember = `-- name: InsertGroupMember :exec
INSERT INTO group_members (
user_id,
group_id
)
VALUES ( $1, $2)
`
type InsertGroupMemberParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupID uuid.UUID `db:"group_id" json:"group_id"`
}
func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error {
_, err := q.db.ExecContext(ctx, insertGroupMember, arg.UserID, arg.GroupID)
return err
}
const updateGroupByID = `-- name: UpdateGroupByID :one
UPDATE
groups
SET
name = $1
WHERE
id = $2
RETURNING id, name, organization_id
`
type UpdateGroupByIDParams struct {
Name string `db:"name" json:"name"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDParams) (Group, error) {
row := q.db.QueryRowContext(ctx, updateGroupByID, arg.Name, arg.ID)
var i Group
err := row.Scan(&i.ID, &i.Name, &i.OrganizationID)
return i, err
}
const deleteLicense = `-- name: DeleteLicense :one
DELETE
FROM licenses
@ -2231,7 +2553,7 @@ func (q *sqlQuerier) InsertDeploymentID(ctx context.Context, value string) error
const getTemplateByID = `-- name: GetTemplateByID :one
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl
FROM
templates
WHERE
@ -2257,13 +2579,15 @@ func (q *sqlQuerier) GetTemplateByID(ctx context.Context, id uuid.UUID) (Templat
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
)
return i, err
}
const getTemplateByOrganizationAndName = `-- name: GetTemplateByOrganizationAndName :one
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl
FROM
templates
WHERE
@ -2297,12 +2621,14 @@ func (q *sqlQuerier) GetTemplateByOrganizationAndName(ctx context.Context, arg G
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
)
return i, err
}
const getTemplates = `-- name: GetTemplates :many
SELECT id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon FROM templates
SELECT id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl FROM templates
ORDER BY (name, id) ASC
`
@ -2329,6 +2655,8 @@ func (q *sqlQuerier) GetTemplates(ctx context.Context) ([]Template, error) {
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
); err != nil {
return nil, err
}
@ -2345,7 +2673,7 @@ func (q *sqlQuerier) GetTemplates(ctx context.Context) ([]Template, error) {
const getTemplatesWithFilter = `-- name: GetTemplatesWithFilter :many
SELECT
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl
FROM
templates
WHERE
@ -2407,6 +2735,8 @@ func (q *sqlQuerier) GetTemplatesWithFilter(ctx context.Context, arg GetTemplate
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
); err != nil {
return nil, err
}
@ -2438,7 +2768,7 @@ INSERT INTO
icon
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl
`
type InsertTemplateParams struct {
@ -2486,6 +2816,8 @@ func (q *sqlQuerier) InsertTemplate(ctx context.Context, arg InsertTemplateParam
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
)
return i, err
}
@ -2545,7 +2877,7 @@ SET
WHERE
id = $1
RETURNING
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon
id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, max_ttl, min_autostart_interval, created_by, icon, user_acl, group_acl
`
type UpdateTemplateMetaByIDParams struct {
@ -2583,6 +2915,8 @@ func (q *sqlQuerier) UpdateTemplateMetaByID(ctx context.Context, arg UpdateTempl
&i.MinAutostartInterval,
&i.CreatedBy,
&i.Icon,
&i.userACL,
&i.groupACL,
)
return i, err
}
@ -3071,16 +3405,36 @@ SELECT
-- status is used to enforce 'suspended' users, as all roles are ignored
-- when suspended.
id, username, status,
-- All user roles, including their org roles.
array_cat(
-- All users are members
array_append(users.rbac_roles, 'member'),
-- All org_members get the org-member role for their orgs
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
AS roles
array_append(users.rbac_roles, 'member'),
(
SELECT
array_agg(org_roles)
FROM
organization_members,
-- All org_members get the org-member role for their orgs
unnest(
array_append(roles, 'organization-member:' || organization_members.organization_id::text)
) AS org_roles
WHERE
user_id = users.id
)
) :: text[] AS roles,
-- All groups the user is in.
(
SELECT
array_agg(
group_members.group_id :: text
)
FROM
group_members
WHERE
user_id = users.id
) :: text[] AS groups
FROM
users
LEFT JOIN organization_members
ON id = user_id
WHERE
id = $1
`
@ -3090,6 +3444,7 @@ type GetAuthorizationUserRolesRow struct {
Username string `db:"username" json:"username"`
Status UserStatus `db:"status" json:"status"`
Roles []string `db:"roles" json:"roles"`
Groups []string `db:"groups" json:"groups"`
}
// This function returns roles for authorization purposes. Implied member roles
@ -3102,6 +3457,7 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
&i.Username,
&i.Status,
pq.Array(&i.Roles),
pq.Array(&i.Groups),
)
return i, err
}
@ -3135,7 +3491,7 @@ func (q *sqlQuerier) GetUserByEmailOrUsername(ctx context.Context, arg GetUserBy
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3166,7 +3522,7 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3285,7 +3641,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3328,7 +3684,7 @@ func (q *sqlQuerier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3364,14 +3720,14 @@ VALUES
`
type InsertUserParams struct {
ID uuid.UUID `db:"id" json:"id"`
Email string `db:"email" json:"email"`
Username string `db:"username" json:"username"`
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
RBACRoles []string `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
ID uuid.UUID `db:"id" json:"id"`
Email string `db:"email" json:"email"`
Username string `db:"username" json:"username"`
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
RBACRoles pq.StringArray `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
}
func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User, error) {
@ -3382,7 +3738,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
arg.HashedPassword,
arg.CreatedAt,
arg.UpdatedAt,
pq.Array(arg.RBACRoles),
arg.RBACRoles,
arg.LoginType,
)
var i User
@ -3394,7 +3750,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3468,7 +3824,7 @@ func (q *sqlQuerier) UpdateUserLastSeenAt(ctx context.Context, arg UpdateUserLas
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3514,7 +3870,7 @@ func (q *sqlQuerier) UpdateUserProfile(ctx context.Context, arg UpdateUserProfil
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3550,7 +3906,7 @@ func (q *sqlQuerier) UpdateUserRoles(ctx context.Context, arg UpdateUserRolesPar
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
@ -3586,7 +3942,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
pq.Array(&i.RBACRoles),
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,

View File

@ -0,0 +1,122 @@
-- name: GetGroupByID :one
SELECT
*
FROM
groups
WHERE
id = $1
LIMIT
1;
-- name: GetGroupByOrgAndName :one
SELECT
*
FROM
groups
WHERE
organization_id = $1
AND
name = $2
LIMIT
1;
-- name: GetUserGroups :many
SELECT
groups.*
FROM
groups
JOIN
group_members
ON
groups.id = group_members.group_id
WHERE
group_members.user_id = $1;
-- name: GetGroupMembers :many
SELECT
users.*
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false';
-- name: GetAllOrganizationMembers :many
SELECT
users.*
FROM
users
JOIN
organization_members
ON
users.id = organization_members.user_id
WHERE
organization_members.organization_id = $1;
-- name: GetGroupsByOrganizationID :many
SELECT
*
FROM
groups
WHERE
organization_id = $1
AND
id != $1;
-- name: InsertGroup :one
INSERT INTO groups (
id,
name,
organization_id
)
VALUES
( $1, $2, $3) RETURNING *;
-- We use the organization_id as the id
-- for simplicity since all users is
-- every member of the org.
-- name: InsertAllUsersGroup :one
INSERT INTO groups (
id,
name,
organization_id
)
VALUES
( sqlc.arg(organization_id), 'Everyone', sqlc.arg(organization_id)) RETURNING *;
-- name: UpdateGroupByID :one
UPDATE
groups
SET
name = $1
WHERE
id = $2
RETURNING *;
-- name: InsertGroupMember :exec
INSERT INTO group_members (
user_id,
group_id
)
VALUES ( $1, $2);
-- name: DeleteGroupMember :exec
DELETE FROM
group_members
WHERE
user_id = $1;
-- name: DeleteGroupByID :exec
DELETE FROM
groups
WHERE
id = $1;

View File

@ -178,15 +178,35 @@ SELECT
-- status is used to enforce 'suspended' users, as all roles are ignored
-- when suspended.
id, username, status,
-- All user roles, including their org roles.
array_cat(
-- All users are members
array_append(users.rbac_roles, 'member'),
-- All org_members get the org-member role for their orgs
array_append(organization_members.roles, 'organization-member:'||organization_members.organization_id::text)) :: text[]
AS roles
array_append(users.rbac_roles, 'member'),
(
SELECT
array_agg(org_roles)
FROM
organization_members,
-- All org_members get the org-member role for their orgs
unnest(
array_append(roles, 'organization-member:' || organization_members.organization_id::text)
) AS org_roles
WHERE
user_id = users.id
)
) :: text[] AS roles,
-- All groups the user is in.
(
SELECT
array_agg(
group_members.group_id :: text
)
FROM
group_members
WHERE
user_id = users.id
) :: text[] AS groups
FROM
users
LEFT JOIN organization_members
ON id = user_id
WHERE
id = @user_id;

View File

@ -16,6 +16,10 @@ packages:
# deleted after generation.
output_db_file_name: db_tmp.go
overrides:
- column: "users.rbac_roles"
go_type: "github.com/lib/pq.StringArray"
rename:
api_key: APIKey
api_key_scope: APIKeyScope
@ -35,3 +39,5 @@ rename:
ip_addresses: IPAddresses
ids: IDs
jwt: JWT
user_acl: userACL
group_acl: groupACL

View File

@ -6,6 +6,8 @@ type UniqueConstraint string
// UniqueConstraint enums.
const (
UniqueGroupMembersUserIDGroupIDKey UniqueConstraint = "group_members_user_id_group_id_key" // ALTER TABLE ONLY group_members ADD CONSTRAINT group_members_user_id_group_id_key UNIQUE (user_id, group_id);
UniqueGroupsNameOrganizationIDKey UniqueConstraint = "groups_name_organization_id_key" // ALTER TABLE ONLY groups ADD CONSTRAINT groups_name_organization_id_key UNIQUE (name, organization_id);
UniqueLicensesJWTKey UniqueConstraint = "licenses_jwt_key" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_jwt_key UNIQUE (jwt);
UniqueParameterSchemasJobIDNameKey UniqueConstraint = "parameter_schemas_job_id_name_key" // ALTER TABLE ONLY parameter_schemas ADD CONSTRAINT parameter_schemas_job_id_name_key UNIQUE (job_id, name);
UniqueParameterValuesScopeIDNameKey UniqueConstraint = "parameter_values_scope_id_name_key" // ALTER TABLE ONLY parameter_values ADD CONSTRAINT parameter_values_scope_id_name_key UNIQUE (scope_id, name);

View File

@ -23,7 +23,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
apiKey := httpmw.APIKey(r)
// This requires the site wide action to create files.
// Once created, a user can read their own files uploaded
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceFile) {
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceFile.WithOwner(apiKey.UserID.String())) {
httpapi.Forbidden(rw)
return
}

View File

@ -54,6 +54,7 @@ type Authorization struct {
ID uuid.UUID
Username string
Roles []string
Groups []string
Scope database.APIKeyScope
}
@ -360,6 +361,7 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
Username: roles.Username,
Roles: roles.Roles,
Scope: key.Scope,
Groups: roles.Groups,
})
next.ServeHTTP(rw, r.WithContext(ctx))

View File

@ -4,23 +4,22 @@ import (
"context"
"crypto/sha256"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/tabbed/pqtype"
"github.com/coder/coder/coderd/database"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
func TestExtractUserRoles(t *testing.T) {
@ -71,14 +70,49 @@ func TestExtractUserRoles(t *testing.T) {
return user, append(roles, append(orgRoles, rbac.RoleMember(), rbac.RoleOrgMember(org.ID))...), token
},
},
{
Name: "MultipleOrgMember",
AddUser: func(db database.Store) (database.User, []string, string) {
roles := []string{}
user, token := addUser(t, db, roles...)
roles = append(roles, rbac.RoleMember())
for i := 0; i < 3; i++ {
organization, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{
ID: uuid.New(),
Name: fmt.Sprintf("testorg%d", i),
Description: "test",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
require.NoError(t, err)
orgRoles := []string{}
if i%2 == 0 {
orgRoles = append(orgRoles, rbac.RoleOrgAdmin(organization.ID))
}
_, err = db.InsertOrganizationMember(context.Background(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Roles: orgRoles,
})
require.NoError(t, err)
roles = append(roles, orgRoles...)
roles = append(roles, rbac.RoleOrgMember(organization.ID))
}
return user, roles, token
},
},
}
for _, c := range testCases {
c := c
t.Run(c.Name, func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
db, _ = dbtestutil.NewDB(t)
user, expRoles, token = c.AddUser(db)
rw = httptest.NewRecorder()
rtr = chi.NewRouter()
@ -118,6 +152,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
Email: "admin@email.com",
Username: "admin",
RBACRoles: roles,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
@ -129,6 +164,13 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
Scope: database.APIKeyScopeAll,
IPAddress: pqtype.Inet{
IPNet: net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.IPMask{0, 0, 0, 0},
},
Valid: true,
},
})
require.NoError(t, err)

View File

@ -0,0 +1,56 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
type groupParamContextKey struct{}
// GroupParam returns the group extracted via the ExtraGroupParam middleware.
func GroupParam(r *http.Request) database.Group {
group, ok := r.Context().Value(groupParamContextKey{}).(database.Group)
if !ok {
panic("developer error: group param middleware not provided")
}
return group
}
// ExtraGroupParam grabs a group from the "group" URL parameter.
func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
groupID, parsed := parseUUID(rw, r, "group")
if !parsed {
return
}
group, err := db.GetGroupByID(r.Context(), groupID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching group.",
Detail: err.Error(),
})
return
}
ctx = context.WithValue(ctx, groupParamContextKey{}, group)
chi.RouteContext(ctx).URLParams.Add("organization", group.OrganizationID.String())
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,103 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/testutil"
)
func TestGroupParam(t *testing.T) {
t.Parallel()
setup := func(t *testing.T) (database.Store, database.Group) {
t.Helper()
ctx, _ := testutil.Context(t)
db := databasefake.New()
orgID := uuid.New()
organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
ID: orgID,
Name: "banana",
Description: "wowie",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
group, err := db.InsertGroup(ctx, database.InsertGroupParams{
ID: uuid.New(),
Name: "yeww",
OrganizationID: organization.ID,
})
require.NoError(t, err)
return db, group
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
var (
db, group = setup(t)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
)
router := chi.NewRouter()
router.Use(httpmw.ExtractGroupParam(db))
router.Get("/", func(w http.ResponseWriter, r *http.Request) {
g := httpmw.GroupParam(r)
require.Equal(t, group, g)
w.WriteHeader(http.StatusOK)
})
rctx := chi.NewRouteContext()
rctx.URLParams.Add("group", group.ID.String())
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
router.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db, group = setup(t)
r = httptest.NewRequest("GET", "/", nil)
w = httptest.NewRecorder()
)
router := chi.NewRouter()
router.Use(httpmw.ExtractGroupParam(db))
router.Get("/", func(w http.ResponseWriter, r *http.Request) {
g := httpmw.GroupParam(r)
require.Equal(t, group, g)
w.WriteHeader(http.StatusOK)
})
rctx := chi.NewRouteContext()
rctx.URLParams.Add("group", uuid.NewString())
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
router.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
}

View File

@ -7,6 +7,7 @@ import (
"net/http"
"github.com/go-chi/chi/v5"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
@ -46,8 +47,20 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
return
}
template, err := db.GetTemplateByID(r.Context(), templateVersion.TemplateID.UUID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching template.",
Detail: err.Error(),
})
return
}
ctx = context.WithValue(ctx, templateVersionParamContextKey{}, templateVersion)
chi.RouteContext(ctx).URLParams.Add("organization", templateVersion.OrganizationID.String())
ctx = context.WithValue(ctx, templateParamContextKey{}, template)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}

View File

@ -60,8 +60,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
}
var organization database.Organization
err = api.Database.InTx(func(store database.Store) error {
organization, err = store.InsertOrganization(ctx, database.InsertOrganizationParams{
err = api.Database.InTx(func(tx database.Store) error {
organization, err = tx.InsertOrganization(ctx, database.InsertOrganizationParams{
ID: uuid.New(),
Name: req.Name,
CreatedAt: database.Now(),
@ -70,7 +70,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
_, err = store.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
_, err = tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: apiKey.UserID,
CreatedAt: database.Now(),
@ -82,6 +82,11 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
if err != nil {
return xerrors.Errorf("create organization admin: %w", err)
}
_, err = tx.InsertAllUsersGroup(ctx, organization.ID)
if err != nil {
return xerrors.Errorf("create %q group: %w", database.AllUsersGroup, err)
}
return nil
})
if err != nil {

View File

@ -219,7 +219,19 @@ func (api *API) parameterRBACResource(rw http.ResponseWriter, r *http.Request, s
case database.ParameterScopeWorkspace:
resource, err = api.Database.GetWorkspaceByID(ctx, scopeID)
case database.ParameterScopeImportJob:
resource, err = api.Database.GetTemplateVersionByJobID(ctx, scopeID)
// I hate myself.
var version database.TemplateVersion
version, err = api.Database.GetTemplateVersionByJobID(ctx, scopeID)
if err != nil {
break
}
var template database.Template
template, err = api.Database.GetTemplateByID(ctx, version.TemplateID.UUID)
if err != nil {
break
}
resource = version.RBACObject(template)
case database.ParameterScopeTemplate:
resource, err = api.Database.GetTemplateByID(ctx, scopeID)
default:

View File

@ -3,7 +3,6 @@ package rbac
import (
"context"
_ "embed"
"fmt"
"sync"
"github.com/open-policy-agent/opa/rego"
@ -15,8 +14,8 @@ import (
)
type Authorizer interface {
ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, object Object) error
PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, objectType string) (PreparedAuthorized, error)
ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, groups []string, action Action, object Object) error
PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, groups []string, action Action, objectType string) (PreparedAuthorized, error)
}
type PreparedAuthorized interface {
@ -27,7 +26,7 @@ type PreparedAuthorized interface {
// Filter takes in a list of objects, and will filter the list removing all
// the elements the subject does not have permission for. All objects must be
// of the same type.
func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, scope Scope, action Action, objects []O) ([]O, error) {
func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, scope Scope, groups []string, action Action, objects []O) ([]O, error) {
ctx, span := tracing.StartSpan(ctx, trace.WithAttributes(
attribute.String("subject_id", subjID),
attribute.StringSlice("subject_roles", subjRoles),
@ -52,7 +51,7 @@ func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, sub
if rbacObj.Type != objectType {
return nil, xerrors.Errorf("object types must be uniform across the set (%s), found %s", objectType, rbacObj)
}
err := auth.ByRoleName(ctx, subjID, subjRoles, scope, action, o.RBACObject())
err := auth.ByRoleName(ctx, subjID, subjRoles, scope, groups, action, o.RBACObject())
if err == nil {
filtered = append(filtered, o)
}
@ -60,7 +59,7 @@ func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, sub
return filtered, nil
}
prepared, err := auth.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType)
prepared, err := auth.PrepareByRoleName(ctx, subjID, subjRoles, scope, groups, action, objectType)
if err != nil {
return nil, xerrors.Errorf("prepare: %w", err)
}
@ -95,21 +94,11 @@ var (
query rego.PreparedEvalQuery
)
const (
rolesOkCheck = "role_ok"
scopeOkCheck = "scope_ok"
)
func NewAuthorizer() *RegoAuthorizer {
queryOnce.Do(func() {
var err error
query, err = rego.New(
// Bind the results to 2 variables for easy checking later.
rego.Query(
fmt.Sprintf("%s := data.authz.role_allow "+
"%s := data.authz.scope_allow",
rolesOkCheck, scopeOkCheck),
),
rego.Query("data.authz.allow"),
rego.Module("policy.rego", policy),
).PrepareForEval(context.Background())
if err != nil {
@ -120,15 +109,16 @@ func NewAuthorizer() *RegoAuthorizer {
}
type authSubject struct {
ID string `json:"id"`
Roles []Role `json:"roles"`
Scope Role `json:"scope"`
ID string `json:"id"`
Roles []Role `json:"roles"`
Groups []string `json:"groups"`
Scope Role `json:"scope"`
}
// ByRoleName will expand all roleNames into roles before calling Authorize().
// This is the function intended to be used outside this package.
// The role is fetched from the builtin map located in memory.
func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, object Object) error {
func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, groups []string, action Action, object Object) error {
roles, err := RolesByNames(roleNames)
if err != nil {
return err
@ -139,7 +129,7 @@ func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNa
return err
}
err = a.Authorize(ctx, subjectID, roles, scopeRole, action, object)
err = a.Authorize(ctx, subjectID, roles, scopeRole, groups, action, object)
if err != nil {
return err
}
@ -149,12 +139,16 @@ func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNa
// Authorize allows passing in custom Roles.
// This is really helpful for unit testing, as we can create custom roles to exercise edge cases.
func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles []Role, scope Role, action Action, object Object) error {
func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles []Role, scope Role, groups []string, action Action, object Object) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
input := map[string]interface{}{
"subject": authSubject{
ID: subjectID,
Roles: roles,
Scope: scope,
ID: subjectID,
Roles: roles,
Groups: groups,
Scope: scope,
},
"object": object,
"action": action,
@ -165,37 +159,19 @@ func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles [
return ForbiddenWithInternal(xerrors.Errorf("eval rego: %w", err), input, results)
}
// We expect only the 2 bindings for scopes and roles checks.
if len(results) == 1 && len(results[0].Bindings) == 2 {
roleCheck, ok := results[0].Bindings[rolesOkCheck].(bool)
if !ok || !roleCheck {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
}
scopeCheck, ok := results[0].Bindings[scopeOkCheck].(bool)
if !ok || !scopeCheck {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
}
// This is purely defensive programming. The two above checks already
// check for 'true' expressions. This is just a sanity check to make
// sure we don't add non-boolean expressions to our query.
// This is super cheap to do, and just adds in some extra safety for
// programmer error.
for _, exp := range results[0].Expressions {
if b, ok := exp.Value.(bool); !ok || !b {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
}
}
return nil
if !results.Allowed() {
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
}
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), input, results)
return nil
}
// Prepare will partially execute the rego policy leaving the object fields unknown (except for the type).
// This will vastly speed up performance if batch authorization on the same type of objects is needed.
func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Role, scope Role, action Action, objectType string) (*PartialAuthorizer, error) {
auth, err := newPartialAuthorizer(ctx, subjectID, roles, scope, action, objectType)
func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Role, scope Role, groups []string, action Action, objectType string) (*PartialAuthorizer, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
auth, err := newPartialAuthorizer(ctx, subjectID, roles, scope, groups, action, objectType)
if err != nil {
return nil, xerrors.Errorf("new partial authorizer: %w", err)
}
@ -203,7 +179,10 @@ func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Rol
return auth, nil
}
func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, objectType string) (PreparedAuthorized, error) {
func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, groups []string, action Action, objectType string) (PreparedAuthorized, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
roles, err := RolesByNames(roleNames)
if err != nil {
return nil, err
@ -214,5 +193,5 @@ func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string,
return nil, err
}
return a.Prepare(ctx, subjectID, roles, scopeRole, action, objectType)
return a.Prepare(ctx, subjectID, roles, scopeRole, groups, action, objectType)
}

View File

@ -19,8 +19,9 @@ type subject struct {
// For the unit test we want to pass in the roles directly, instead of just
// by name. This allows us to test custom roles that do not exist in the product,
// but test edge cases of the implementation.
Roles []Role `json:"roles"`
Scope Role `json:"scope"`
Roles []Role `json:"roles"`
Groups []string `json:"groups"`
Scope Role `json:"scope"`
}
type fakeObject struct {
@ -41,7 +42,8 @@ func (w fakeObject) RBACObject() Object {
func TestFilterError(t *testing.T) {
t.Parallel()
auth := NewAuthorizer()
_, err := Filter(context.Background(), auth, uuid.NewString(), []string{}, ScopeAll, ActionRead, []Object{ResourceUser, ResourceWorkspace})
_, err := Filter(context.Background(), auth, uuid.NewString(), []string{}, ScopeAll, []string{}, ActionRead, []Object{ResourceUser, ResourceWorkspace})
require.ErrorContains(t, err, "object types must be uniform")
}
@ -169,7 +171,7 @@ func TestFilter(t *testing.T) {
var allowedCount int
for i, obj := range localObjects {
obj.Type = tc.ObjectType
err := auth.ByRoleName(ctx, tc.SubjectID, tc.Roles, scope, ActionRead, obj.RBACObject())
err := auth.ByRoleName(ctx, tc.SubjectID, tc.Roles, scope, []string{}, ActionRead, obj.RBACObject())
obj.Allowed = err == nil
if err == nil {
allowedCount++
@ -178,7 +180,7 @@ func TestFilter(t *testing.T) {
}
// Run by filter
list, err := Filter(ctx, auth, tc.SubjectID, tc.Roles, scope, tc.Action, localObjects)
list, err := Filter(ctx, auth, tc.SubjectID, tc.Roles, scope, []string{}, tc.Action, localObjects)
require.NoError(t, err)
require.Equal(t, allowedCount, len(list), "expected number of allowed")
for _, obj := range list {
@ -193,15 +195,82 @@ func TestAuthorizeDomain(t *testing.T) {
t.Parallel()
defOrg := uuid.New()
unuseID := uuid.New()
allUsersGroup := "Everyone"
user := subject{
UserID: "me",
Scope: must(ScopeRole(ScopeAll)),
Groups: []string{allUsersGroup},
Roles: []Role{
must(RoleByName(RoleMember())),
must(RoleByName(RoleOrgMember(defOrg))),
},
}
testAuthorize(t, "UserACLList", user, []authTestCase{
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]Action{
user.UserID: allActions(),
}),
actions: allActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]Action{
user.UserID: {WildcardSymbol},
}),
actions: allActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(unuseID).WithACLUserList(map[string][]Action{
user.UserID: {ActionRead, ActionUpdate},
}),
actions: []Action{ActionCreate, ActionDelete},
allow: false,
},
{
// By default users cannot update templates
resource: ResourceTemplate.InOrg(defOrg).WithACLUserList(map[string][]Action{
user.UserID: {ActionUpdate},
}),
actions: []Action{ActionUpdate},
allow: true,
},
})
testAuthorize(t, "GroupACLList", user, []authTestCase{
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]Action{
allUsersGroup: allActions(),
}),
actions: allActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]Action{
allUsersGroup: {WildcardSymbol},
}),
actions: allActions(),
allow: true,
},
{
resource: ResourceWorkspace.WithOwner(unuseID.String()).InOrg(defOrg).WithGroupACL(map[string][]Action{
allUsersGroup: {ActionRead, ActionUpdate},
}),
actions: []Action{ActionCreate, ActionDelete},
allow: false,
},
{
// By default users cannot update templates
resource: ResourceTemplate.InOrg(defOrg).WithGroupACL(map[string][]Action{
allUsersGroup: {ActionUpdate},
}),
actions: []Action{ActionUpdate},
allow: true,
},
})
testAuthorize(t, "Member", user, []authTestCase{
// Org + me
{resource: ResourceWorkspace.InOrg(defOrg).WithOwner(user.UserID), actions: allActions(), allow: true},
@ -743,9 +812,6 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
for _, cases := range sets {
for i, c := range cases {
c := c
if c.resource.Type != "application_connect" {
continue
}
caseName := fmt.Sprintf("%s/%d", name, i)
t.Run(caseName, func(t *testing.T) {
t.Parallel()
@ -753,23 +819,21 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)
authError := authorizer.Authorize(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource)
authError := authorizer.Authorize(ctx, subject.UserID, subject.Roles, subject.Scope, subject.Groups, a, c.resource)
d, _ := json.Marshal(map[string]interface{}{
"subject": subject,
"object": c.resource,
"action": a,
})
// Logging only
t.Logf("input: %s", string(d))
if authError != nil {
var uerr *UnauthorizedError
xerrors.As(authError, &uerr)
d, _ := json.Marshal(uerr.Input())
t.Logf("input: %s", string(d))
t.Logf("internal error: %+v", uerr.Internal().Error())
t.Logf("output: %+v", uerr.Output())
} else {
d, _ := json.Marshal(map[string]interface{}{
"subject": subject,
"object": c.resource,
"action": a,
})
t.Log(string(d))
}
if c.allow {
@ -778,19 +842,17 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes
assert.Error(t, authError, "expected unauthorized")
}
partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, a, c.resource.Type)
partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, subject.Scope, subject.Groups, a, c.resource.Type)
require.NoError(t, err, "make prepared authorizer")
// Ensure the partial can compile to a SQL clause.
// This does not guarantee that the clause is valid SQL.
_, err = Compile(partialAuthz.partialQueries)
_, err = Compile(partialAuthz)
require.NoError(t, err, "compile prepared authorizer")
// Also check the rego policy can form a valid partial query result.
// This ensures we can convert the queries into SQL WHERE clauses in the future.
// If this function returns 'Support' sections, then we cannot convert the query into SQL.
d, _ := json.Marshal(partialAuthz.input)
t.Logf("input: %s", string(d))
for _, q := range partialAuthz.partialQueries.Queries {
t.Logf("query: %+v", q.String())
}

View File

@ -63,8 +63,8 @@ var (
return Role{
Name: owner,
DisplayName: "Owner",
Site: permissions(map[Object][]Action{
ResourceWildcard: {WildcardSymbol},
Site: permissions(map[string][]Action{
ResourceWildcard.Type: {WildcardSymbol},
}),
}
},
@ -74,15 +74,15 @@ var (
return Role{
Name: member,
DisplayName: "",
Site: permissions(map[Object][]Action{
Site: permissions(map[string][]Action{
// All users can read all other users and know they exist.
ResourceUser: {ActionRead},
ResourceRoleAssignment: {ActionRead},
ResourceUser.Type: {ActionRead},
ResourceRoleAssignment.Type: {ActionRead},
// All users can see the provisioner daemons.
ResourceProvisionerDaemon: {ActionRead},
ResourceProvisionerDaemon.Type: {ActionRead},
}),
User: permissions(map[Object][]Action{
ResourceWildcard: {WildcardSymbol},
User: permissions(map[string][]Action{
ResourceWildcard.Type: {WildcardSymbol},
}),
}
},
@ -94,11 +94,11 @@ var (
return Role{
Name: auditor,
DisplayName: "Auditor",
Site: permissions(map[Object][]Action{
Site: permissions(map[string][]Action{
// Should be able to read all template details, even in orgs they
// are not in.
ResourceTemplate: {ActionRead},
ResourceAuditLog: {ActionRead},
ResourceTemplate.Type: {ActionRead},
ResourceAuditLog.Type: {ActionRead},
}),
}
},
@ -107,13 +107,13 @@ var (
return Role{
Name: templateAdmin,
DisplayName: "Template Admin",
Site: permissions(map[Object][]Action{
ResourceTemplate: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
Site: permissions(map[string][]Action{
ResourceTemplate.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
// CRUD all files, even those they did not upload.
ResourceFile: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceWorkspace: {ActionRead},
ResourceFile.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceWorkspace.Type: {ActionRead},
// CRUD to provisioner daemons for now.
ResourceProvisionerDaemon: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceProvisionerDaemon.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
}),
}
},
@ -122,11 +122,11 @@ var (
return Role{
Name: userAdmin,
DisplayName: "User Admin",
Site: permissions(map[Object][]Action{
ResourceRoleAssignment: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceUser: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
Site: permissions(map[string][]Action{
ResourceRoleAssignment.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceUser.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
// Full perms to manage org members
ResourceOrganizationMember: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
ResourceOrganizationMember.Type: {ActionCreate, ActionRead, ActionUpdate, ActionDelete},
}),
}
},
@ -168,13 +168,12 @@ var (
Action: ActionRead,
},
{
// All org members can read templates in the org
ResourceType: ResourceTemplate.Type,
// Can read available roles.
ResourceType: ResourceOrgRoleAssignment.Type,
Action: ActionRead,
},
{
// Can read available roles.
ResourceType: ResourceOrgRoleAssignment.Type,
ResourceType: ResourceGroup.Type,
Action: ActionRead,
},
},
@ -390,14 +389,14 @@ func roleSplit(role string) (name string, orgID string, err error) {
// permissions is just a helper function to make building roles that list out resources
// and actions a bit easier.
func permissions(perms map[Object][]Action) []Permission {
func permissions(perms map[string][]Action) []Permission {
list := make([]Permission, 0, len(perms))
for k, actions := range perms {
for _, act := range actions {
act := act
list = append(list, Permission{
Negate: false,
ResourceType: k.Type,
ResourceType: k,
Action: act,
})
}

View File

@ -32,6 +32,7 @@ func BenchmarkRBACFilter(b *testing.B) {
benchCases := []struct {
Name string
Roles []string
Groups []string
UserID uuid.UUID
Scope rbac.Scope
}{
@ -87,7 +88,7 @@ func BenchmarkRBACFilter(b *testing.B) {
b.Run(c.Name, func(b *testing.B) {
objects := benchmarkSetup(orgs, users, b.N)
b.ResetTimer()
allowed, err := rbac.Filter(context.Background(), authorizer, c.UserID.String(), c.Roles, c.Scope, rbac.ActionRead, objects)
allowed, err := rbac.Filter(context.Background(), authorizer, c.UserID.String(), c.Roles, c.Scope, c.Groups, rbac.ActionRead, objects)
require.NoError(b, err)
var _ = allowed
})
@ -96,11 +97,17 @@ func BenchmarkRBACFilter(b *testing.B) {
func benchmarkSetup(orgs []uuid.UUID, users []uuid.UUID, size int) []rbac.Object {
// Create a "random" but deterministic set of objects.
aclList := map[string][]rbac.Action{
uuid.NewString(): {rbac.ActionRead, rbac.ActionUpdate},
uuid.NewString(): {rbac.ActionCreate},
}
objectList := make([]rbac.Object, size)
for i := range objectList {
objectList[i] = rbac.ResourceWorkspace.
InOrg(orgs[i%len(orgs)]).
WithOwner(users[i%len(users)].String())
WithOwner(users[i%len(users)].String()).
WithACLUserList(aclList).
WithGroupACL(aclList)
}
return objectList
@ -111,6 +118,7 @@ type authSubject struct {
Name string
UserID string
Roles []string
Groups []string
}
func TestRolePermissions(t *testing.T) {
@ -227,8 +235,8 @@ func TestRolePermissions(t *testing.T) {
Actions: []rbac.Action{rbac.ActionRead},
Resource: rbac.ResourceTemplate.InOrg(orgID),
AuthorizeMap: map[bool][]authSubject{
true: {owner, orgMemberMe, orgAdmin, templateAdmin},
false: {memberMe, otherOrgAdmin, otherOrgMember, userAdmin},
true: {owner, orgAdmin, templateAdmin},
false: {memberMe, otherOrgAdmin, otherOrgMember, userAdmin, orgMemberMe},
},
},
{
@ -242,7 +250,7 @@ func TestRolePermissions(t *testing.T) {
},
{
Name: "MyFile",
Actions: []rbac.Action{rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
Actions: []rbac.Action{rbac.ActionCreate, rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
Resource: rbac.ResourceFile.WithOwner(currentUser.String()),
AuthorizeMap: map[bool][]authSubject{
true: {owner, memberMe, orgMemberMe, templateAdmin},
@ -348,6 +356,19 @@ func TestRolePermissions(t *testing.T) {
false: {memberMe, otherOrgAdmin, otherOrgMember, templateAdmin},
},
},
{
Name: "AllUsersGroupACL",
Actions: []rbac.Action{rbac.ActionRead},
Resource: rbac.ResourceTemplate.InOrg(orgID).WithGroupACL(
map[string][]rbac.Action{
orgID.String(): {rbac.ActionRead},
}),
AuthorizeMap: map[bool][]authSubject{
true: {owner, orgAdmin, orgMemberMe, templateAdmin},
false: {memberMe, otherOrgAdmin, otherOrgMember, userAdmin},
},
},
}
for _, c := range testCases {
@ -365,7 +386,7 @@ func TestRolePermissions(t *testing.T) {
delete(remainingSubjs, subj.Name)
msg := fmt.Sprintf("%s as %q doing %q on %q", c.Name, subj.Name, action, c.Resource.Type)
// TODO: scopey
err := auth.ByRoleName(context.Background(), subj.UserID, subj.Roles, rbac.ScopeAll, action, c.Resource)
err := auth.ByRoleName(context.Background(), subj.UserID, subj.Roles, rbac.ScopeAll, subj.Groups, action, c.Resource)
if result {
assert.NoError(t, err, fmt.Sprintf("Should pass: %s", msg))
} else {

View File

@ -54,6 +54,14 @@ var (
Type: "template",
}
// ResourceGroup CRUD. Org admins only.
// create/delete = Make or delete a new group.
// update = Update the name or members of a group.
// read = Read groups and their members.
ResourceGroup = Object{
Type: "group",
}
ResourceFile = Object{
Type: "file",
}
@ -152,7 +160,9 @@ type Object struct {
// Type is "workspace", "project", "app", etc
Type string `json:"type"`
// TODO: SharedUsers?
ACLUserList map[string][]Action ` json:"acl_user_list"`
ACLGroupList map[string][]Action ` json:"acl_group_list"`
}
func (z Object) RBACObject() Object {
@ -162,26 +172,53 @@ func (z Object) RBACObject() Object {
// All returns an object matching all resources of the same type.
func (z Object) All() Object {
return Object{
Owner: "",
OrgID: "",
Type: z.Type,
Owner: "",
OrgID: "",
Type: z.Type,
ACLUserList: map[string][]Action{},
ACLGroupList: map[string][]Action{},
}
}
// InOrg adds an org OwnerID to the resource
func (z Object) InOrg(orgID uuid.UUID) Object {
return Object{
Owner: z.Owner,
OrgID: orgID.String(),
Type: z.Type,
Owner: z.Owner,
OrgID: orgID.String(),
Type: z.Type,
ACLUserList: z.ACLUserList,
ACLGroupList: z.ACLGroupList,
}
}
// WithOwner adds an OwnerID to the resource
func (z Object) WithOwner(ownerID string) Object {
return Object{
Owner: ownerID,
OrgID: z.OrgID,
Type: z.Type,
Owner: ownerID,
OrgID: z.OrgID,
Type: z.Type,
ACLUserList: z.ACLUserList,
ACLGroupList: z.ACLGroupList,
}
}
// WithACLUserList adds an ACL list to a given object
func (z Object) WithACLUserList(acl map[string][]Action) Object {
return Object{
Owner: z.Owner,
OrgID: z.OrgID,
Type: z.Type,
ACLUserList: acl,
ACLGroupList: z.ACLGroupList,
}
}
func (z Object) WithGroupACL(groups map[string][]Action) Object {
return Object{
Owner: z.Owner,
OrgID: z.OrgID,
Type: z.Type,
ACLUserList: z.ACLUserList,
ACLGroupList: groups,
}
}

View File

@ -29,7 +29,7 @@ type PartialAuthorizer struct {
var _ PreparedAuthorized = (*PartialAuthorizer)(nil)
func (pa *PartialAuthorizer) Compile() (AuthorizeFilter, error) {
filter, err := Compile(pa.partialQueries)
filter, err := Compile(pa)
if err != nil {
return nil, xerrors.Errorf("compile: %w", err)
}
@ -99,7 +99,7 @@ EachQueryLoop:
// inspect this any further. But just in case, we will verify each expression
// did resolve to 'true'. This is purely defensive programming.
for _, exp := range results[0].Expressions {
if exp.String() != "true" {
if v, ok := exp.Value.(bool); !ok || !v {
continue EachQueryLoop
}
}
@ -110,15 +110,16 @@ EachQueryLoop:
return ForbiddenWithInternal(xerrors.Errorf("policy disallows request"), pa.input, nil)
}
func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, scope Role, action Action, objectType string) (*PartialAuthorizer, error) {
func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, scope Role, groups []string, action Action, objectType string) (*PartialAuthorizer, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
input := map[string]interface{}{
"subject": authSubject{
ID: subjectID,
Roles: roles,
Scope: scope,
ID: subjectID,
Roles: roles,
Scope: scope,
Groups: groups,
},
"object": map[string]string{
"type": objectType,
@ -129,11 +130,13 @@ func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, s
// Run the rego policy with a few unknown fields. This should simplify our
// policy to a set of queries.
partialQueries, err := rego.New(
rego.Query("data.authz.role_allow = true data.authz.scope_allow = true"),
rego.Query("data.authz.allow = true"),
rego.Module("policy.rego", policy),
rego.Unknowns([]string{
"input.object.owner",
"input.object.org_owner",
"input.object.acl_user_list",
"input.object.acl_group_list",
}),
rego.Input(input),
).Partial(ctx)

View File

@ -2,8 +2,8 @@ package authz
import future.keywords
# A great playground: https://play.openpolicyagent.org/
# Helpful cli commands to debug.
# opa eval --format=pretty 'data.authz.role_allow data.authz.scope_allow' -d policy.rego -i input.json
# opa eval --partial --format=pretty 'data.authz.role_allow = true data.authz.scope_allow = true' -d policy.rego --unknowns input.object.owner --unknowns input.object.org_owner -i input.json
# opa eval --format=pretty 'data.authz.allow' -d policy.rego -i input.json
# opa eval --partial --format=pretty 'data.authz.allow' -d policy.rego --unknowns input.object.owner --unknowns input.object.org_owner --unknowns input.object.acl_user_list --unknowns input.object.acl_group_list -i input.json
#
# This policy is specifically constructed to compress to a set of queries if the
@ -119,9 +119,13 @@ org_mem := true {
input.object.org_owner in org_members
}
org_ok {
org_mem
}
# If the object has no organization, then the user is also considered part of
# the non-existent org.
org_mem := true {
org_ok {
input.object.org_owner == ""
}
@ -156,7 +160,6 @@ user_allow(roles) := num {
# Allow query:
# data.authz.role_allow = true data.authz.scope_allow = true
default role_allow = false
role_allow {
site = 1
}
@ -171,12 +174,10 @@ role_allow {
not org = -1
# If we are not a member of an org, and the object has an org, then we are
# not authorized. This is an "implied -1" for not being in the org.
org_mem
org_ok
user = 1
}
default scope_allow = false
scope_allow {
scope_site = 1
}
@ -191,6 +192,48 @@ scope_allow {
not scope_org = -1
# If we are not a member of an org, and the object has an org, then we are
# not authorized. This is an "implied -1" for not being in the org.
org_mem
org_ok
scope_user = 1
}
# ACL for users
acl_allow {
# Should you have to be a member of the org too?
perms := input.object.acl_user_list[input.subject.id]
# Either the input action or wildcard
[input.action, "*"][_] in perms
}
# ACL for groups
acl_allow {
# If there is no organization owner, the object cannot be owned by an
# org_scoped team.
org_mem
group := input.subject.groups[_]
perms := input.object.acl_group_list[group]
# Either the input action or wildcard
[input.action, "*"][_] in perms
}
# ACL for 'all_users' special group
acl_allow {
org_mem
perms := input.object.acl_group_list[input.object.org_owner]
[input.action, "*"][_] in perms
}
###############
# Final Allow
# The role or the ACL must allow the action. Scopes can be used to limit,
# so scope_allow must always be true.
allow {
role_allow
scope_allow
}
# ACL list must also have the scope_allow to pass
allow {
acl_allow
scope_allow
}

View File

@ -1,13 +1,13 @@
package rbac
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
)
@ -16,6 +16,9 @@ type TermType string
const (
VarTypeJsonbTextArray TermType = "jsonb-text-array"
VarTypeText TermType = "text"
VarTypeBoolean TermType = "boolean"
// VarTypeSkip means this variable does not exist to use.
VarTypeSkip TermType = "skip"
)
type SQLColumn struct {
@ -79,19 +82,54 @@ func DefaultConfig() SQLConfig {
}
}
func NoACLConfig() SQLConfig {
return SQLConfig{
Variables: []SQLColumn{
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_group_list\.?(.*)$`),
ColumnSelect: "",
Type: VarTypeSkip,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.acl_user_list\.?(.*)$`),
ColumnSelect: "",
Type: VarTypeSkip,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.org_owner$`),
ColumnSelect: "organization_id :: text",
Type: VarTypeText,
},
{
RegoMatch: regexp.MustCompile(`^input\.object\.owner$`),
ColumnSelect: "owner_id :: text",
Type: VarTypeText,
},
},
}
}
type AuthorizeFilter interface {
// RegoString is used in debugging to see the original rego expression.
RegoString() string
// SQLString returns the SQL expression that can be used in a WHERE clause.
SQLString(cfg SQLConfig) string
Expression
// Eval is required for the fake in memory database to work. The in memory
// database can use this function to filter the results.
Eval(object Object) bool
}
// expressionTop handles Eval(object Object) for in memory expressions
type expressionTop struct {
Expression
Auth *PartialAuthorizer
}
func (e expressionTop) Eval(object Object) bool {
return e.Auth.Authorize(context.Background(), object) == nil
}
// Compile will convert a rego query AST into our custom types. The output is
// an AST that can be used to generate SQL.
func Compile(partialQueries *rego.PartialQueries) (Expression, error) {
func Compile(pa *PartialAuthorizer) (AuthorizeFilter, error) {
partialQueries := pa.partialQueries
if len(partialQueries.Support) > 0 {
return nil, xerrors.Errorf("cannot convert support rules, expect 0 found %d", len(partialQueries.Support))
}
@ -128,11 +166,15 @@ func Compile(partialQueries *rego.PartialQueries) (Expression, error) {
}
builder.WriteString(partialQueries.Queries[i].String())
}
return expOr{
exp := expOr{
base: base{
Rego: builder.String(),
},
Expressions: result,
}
return expressionTop{
Expression: &exp,
Auth: pa,
}, nil
}
@ -218,21 +260,22 @@ func processTerms(expected int, terms []*ast.Term) ([]Term, error) {
}
func processTerm(term *ast.Term) (Term, error) {
base := base{Rego: term.String()}
termBase := base{Rego: term.String()}
switch v := term.Value.(type) {
case ast.Boolean:
return &termBoolean{
base: base,
base: termBase,
Value: bool(v),
}, nil
case ast.Ref:
obj := &termObject{
base: base,
Variables: []termVariable{},
base: termBase,
Path: []Term{},
}
var idx int
// A ref is a set of terms. If the first term is a var, then the
// following terms are the path to the value.
isRef := true
var builder strings.Builder
for _, term := range v {
if idx == 0 {
@ -241,15 +284,37 @@ func processTerm(term *ast.Term) (Term, error) {
}
}
if _, ok := term.Value.(ast.Ref); ok {
_, newRef := term.Value.(ast.Ref)
if newRef ||
// This is an unfortunate hack. To fix this, we need to rewrite
// our SQL config as a path ([]string{"input", "object", "acl_group"}).
// In the rego AST, there is no difference between selecting
// a field by a variable, and selecting a field by a literal (string).
// This was a misunderstanding.
// Example (these are equivalent by AST):
// input.object.acl_group_list['4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75']
// input.object.acl_group_list.organization_id
//
// This is not equivalent
// input.object.acl_group_list[input.object.organization_id]
//
// If this becomes even more hairy, we should fix the sql config.
builder.String() == "input.object.acl_group_list" ||
builder.String() == "input.object.acl_user_list" {
if !newRef {
isRef = false
}
// New obj
obj.Variables = append(obj.Variables, termVariable{
base: base,
obj.Path = append(obj.Path, termVariable{
base: base{
Rego: builder.String(),
},
Name: builder.String(),
})
builder.Reset()
idx = 0
}
if builder.Len() != 0 {
builder.WriteString(".")
}
@ -257,20 +322,31 @@ func processTerm(term *ast.Term) (Term, error) {
idx++
}
obj.Variables = append(obj.Variables, termVariable{
base: base,
Name: builder.String(),
})
if isRef {
obj.Path = append(obj.Path, termVariable{
base: base{
Rego: builder.String(),
},
Name: builder.String(),
})
} else {
obj.Path = append(obj.Path, termString{
base: base{
Rego: fmt.Sprintf("%q", builder.String()),
},
Value: builder.String(),
})
}
return obj, nil
case ast.Var:
return &termVariable{
Name: trimQuotes(v.String()),
base: base,
base: termBase,
}, nil
case ast.String:
return &termString{
Value: trimQuotes(v.String()),
base: base,
base: termBase,
}, nil
case ast.Set:
slice := v.Slice()
@ -285,7 +361,7 @@ func processTerm(term *ast.Term) (Term, error) {
return &termSet{
Value: set,
base: base,
base: termBase,
}, nil
default:
return nil, xerrors.Errorf("invalid term: %T not supported, %q", v, term.String())
@ -306,7 +382,10 @@ func (b base) RegoString() string {
//
// Eg: neq(input.object.org_owner, "") AND input.object.org_owner == "foo"
type Expression interface {
AuthorizeFilter
// RegoString is used in debugging to see the original rego expression.
RegoString() string
// SQLString returns the SQL expression that can be used in a WHERE clause.
SQLString(cfg SQLConfig) string
}
type expAnd struct {
@ -326,15 +405,6 @@ func (t expAnd) SQLString(cfg SQLConfig) string {
return "(" + strings.Join(exprs, " AND ") + ")"
}
func (t expAnd) Eval(object Object) bool {
for _, expr := range t.Expressions {
if !expr.Eval(object) {
return false
}
}
return true
}
type expOr struct {
base
Expressions []Expression
@ -352,15 +422,6 @@ func (t expOr) SQLString(cfg SQLConfig) string {
return "(" + strings.Join(exprs, " OR ") + ")"
}
func (t expOr) Eval(object Object) bool {
for _, expr := range t.Expressions {
if expr.Eval(object) {
return true
}
}
return false
}
// Operator joins terms together to form an expression.
// Operators are also expressions.
//
@ -384,14 +445,6 @@ func (t opEqual) SQLString(cfg SQLConfig) string {
return fmt.Sprintf("%s %s %s", t.Terms[0].SQLString(cfg), op, t.Terms[1].SQLString(cfg))
}
func (t opEqual) Eval(object Object) bool {
a, b := t.Terms[0].EvalTerm(object), t.Terms[1].EvalTerm(object)
if t.Not {
return a != b
}
return a == b
}
// opInternalMember2 is checking if the first term is a member of the second term.
// The second term is a set or list.
type opInternalMember2 struct {
@ -400,20 +453,6 @@ type opInternalMember2 struct {
Haystack Term
}
func (t opInternalMember2) Eval(object Object) bool {
a, b := t.Needle.EvalTerm(object), t.Haystack.EvalTerm(object)
bset, ok := b.([]interface{})
if !ok {
return false
}
for _, elem := range bset {
if a == elem {
return true
}
}
return false
}
func (t opInternalMember2) SQLString(cfg SQLConfig) string {
if haystack, ok := t.Haystack.(*termObject); ok {
// This is a special case where the haystack is a jsonb array.
@ -425,9 +464,14 @@ func (t opInternalMember2) SQLString(cfg SQLConfig) string {
// having to add more "if" branches here.
// But until we need more cases, our basic type system is ok, and
// this is the only case we need to handle.
if haystack.SQLType(cfg) == VarTypeJsonbTextArray {
sqlType := haystack.SQLType(cfg)
if sqlType == VarTypeJsonbTextArray {
return fmt.Sprintf("%s ? %s", haystack.SQLString(cfg), t.Needle.SQLString(cfg))
}
if sqlType == VarTypeSkip {
return "true"
}
}
return fmt.Sprintf("%s = ANY(%s)", t.Needle.SQLString(cfg), t.Haystack.SQLString(cfg))
@ -440,9 +484,7 @@ func (t opInternalMember2) SQLString(cfg SQLConfig) string {
type Term interface {
RegoString() string
SQLString(cfg SQLConfig) string
// Eval will evaluate the term
// Terms can eval to any type. The operator/expression will type check.
EvalTerm(object Object) interface{}
SQLType(cfg SQLConfig) TermType
}
type termString struct {
@ -450,10 +492,6 @@ type termString struct {
Value string
}
func (t termString) EvalTerm(_ Object) interface{} {
return t.Value
}
func (t termString) SQLString(_ SQLConfig) string {
return "'" + t.Value + "'"
}
@ -471,14 +509,7 @@ func (termString) SQLType(_ SQLConfig) TermType {
// term type.
type termObject struct {
base
Variables []termVariable
}
func (t termObject) EvalTerm(obj Object) interface{} {
if len(t.Variables) == 0 {
return t.Variables[0].EvalTerm(obj)
}
panic("no nested structures are supported yet")
Path []Term
}
func (t termObject) SQLType(cfg SQLConfig) TermType {
@ -486,30 +517,30 @@ func (t termObject) SQLType(cfg SQLConfig) TermType {
// is the resulting type. This is correct for our use case.
// Solving this more generally requires a full type system, which is
// excessive for our mostly static policy.
return t.Variables[0].SQLType(cfg)
return t.Path[0].SQLType(cfg)
}
func (t termObject) SQLString(cfg SQLConfig) string {
if len(t.Variables) == 1 {
return t.Variables[0].SQLString(cfg)
if len(t.Path) == 1 {
return t.Path[0].SQLString(cfg)
}
// Combine the last 2 variables into 1 variable.
end := t.Variables[len(t.Variables)-1]
before := t.Variables[len(t.Variables)-2]
end := t.Path[len(t.Path)-1]
before := t.Path[len(t.Path)-2]
// Recursively solve the SQLString by removing the last nested reference.
// This continues until we have a single variable.
return termObject{
base: t.base,
Variables: append(
t.Variables[:len(t.Variables)-2],
Path: append(
t.Path[:len(t.Path)-2],
termVariable{
base: base{
Rego: before.base.Rego + "[" + end.base.Rego + "]",
Rego: before.RegoString() + "[" + end.RegoString() + "]",
},
// Convert the end to SQL string. We evaluate each term
// one at a time.
Name: before.Name + "." + end.SQLString(cfg),
Name: before.RegoString() + "." + end.SQLString(cfg),
},
),
}.SQLString(cfg)
@ -520,19 +551,6 @@ type termVariable struct {
Name string
}
func (t termVariable) EvalTerm(obj Object) interface{} {
switch t.Name {
case "input.object.org_owner":
return obj.OrgID
case "input.object.owner":
return obj.Owner
case "input.object.type":
return obj.Type
default:
return fmt.Sprintf("'Unknown variable %s'", t.Name)
}
}
func (t termVariable) SQLType(cfg SQLConfig) TermType {
if col := t.ColumnConfig(cfg); col != nil {
return col.Type
@ -576,13 +594,15 @@ type termSet struct {
Value []Term
}
func (t termSet) EvalTerm(obj Object) interface{} {
set := make([]interface{}, 0, len(t.Value))
for _, term := range t.Value {
set = append(set, term.EvalTerm(obj))
func (t termSet) SQLType(cfg SQLConfig) TermType {
if len(t.Value) == 0 {
return VarTypeText
}
return set
// Without a full type system, let's just assume the type of the first var
// is the resulting type. This is correct for our use case.
// Solving this more generally requires a full type system, which is
// excessive for our mostly static policy.
return t.Value[0].SQLType(cfg)
}
func (t termSet) SQLString(cfg SQLConfig) string {
@ -599,11 +619,11 @@ type termBoolean struct {
Value bool
}
func (t termBoolean) Eval(_ Object) bool {
return t.Value
func (termBoolean) SQLType(SQLConfig) TermType {
return VarTypeBoolean
}
func (t termBoolean) EvalTerm(_ Object) interface{} {
func (t termBoolean) Eval(_ Object) bool {
return t.Value
}

View File

@ -1,6 +1,7 @@
package rbac
import (
"context"
"testing"
"github.com/open-policy-agent/opa/ast"
@ -11,17 +12,10 @@ import (
func TestCompileQuery(t *testing.T) {
t.Parallel()
opts := ast.ParserOptions{
AllFutureKeywords: true,
}
t.Run("EmptyQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("")),
},
Support: []*ast.Module{},
})
expression, err := Compile(partialQueries(t, ""))
require.NoError(t, err, "compile empty")
require.Equal(t, "true", expression.RegoString(), "empty query is rego 'true'")
@ -30,12 +24,7 @@ func TestCompileQuery(t *testing.T) {
t.Run("TrueQuery", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
must(ast.ParseBody("true")),
},
Support: []*ast.Module{},
})
expression, err := Compile(partialQueries(t, "true"))
require.NoError(t, err, "compile")
require.Equal(t, "true", expression.RegoString(), "true query is rego 'true'")
@ -44,49 +33,118 @@ func TestCompileQuery(t *testing.T) {
t.Run("ACLIn", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list.allUsers`, opts),
},
Support: []*ast.Module{},
})
expression, err := Compile(partialQueries(t, `"*" in input.object.acl_group_list.allUsers`))
require.NoError(t, err, "compile")
require.Equal(t, `internal.member_2("*", input.object.acl_group_list.allUsers)`, expression.RegoString(), "convert to internal_member")
require.Equal(t, `group_acl->allUsers ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
require.Equal(t, `group_acl->'allUsers' ? '*'`, expression.SQLString(DefaultConfig()), "jsonb in")
})
t.Run("Complex", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
ast.MustParseBodyWithOpts(`input.object.org_owner in {"a", "b", "c"}`, opts),
ast.MustParseBodyWithOpts(`input.object.org_owner != ""`, opts),
ast.MustParseBodyWithOpts(`"read" in input.object.acl_group_list.allUsers`, opts),
ast.MustParseBodyWithOpts(`"read" in input.object.acl_user_list.me`, opts),
},
Support: []*ast.Module{},
})
expression, err := Compile(partialQueries(t,
`input.object.org_owner != ""`,
`input.object.org_owner in {"a", "b", "c"}`,
`input.object.org_owner != ""`,
`"read" in input.object.acl_group_list.allUsers`,
`"read" in input.object.acl_user_list.me`,
))
require.NoError(t, err, "compile")
require.Equal(t, `(organization_id :: text != '' OR `+
`organization_id :: text = ANY(ARRAY ['a','b','c']) OR `+
`organization_id :: text != '' OR `+
`group_acl->allUsers ? 'read' OR `+
`user_acl->me ? 'read')`,
`group_acl->'allUsers' ? 'read' OR `+
`user_acl->'me' ? 'read')`,
expression.SQLString(DefaultConfig()), "complex")
})
t.Run("SetDereference", func(t *testing.T) {
t.Parallel()
expression, err := Compile(&rego.PartialQueries{
Queries: []ast.Body{
ast.MustParseBodyWithOpts(`"*" in input.object.acl_group_list[input.object.org_owner]`, opts),
},
Support: []*ast.Module{},
})
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list[input.object.org_owner]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `group_acl->organization_id :: text ? '*'`,
expression.SQLString(DefaultConfig()), "set dereference")
})
t.Run("JsonbLiteralDereference", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `group_acl->'4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75' ? '*'`,
expression.SQLString(DefaultConfig()), "literal dereference")
})
t.Run("NoACLColumns", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"*" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
require.Equal(t, `true`,
expression.SQLString(NoACLConfig()), "literal dereference")
})
}
func TestEvalQuery(t *testing.T) {
t.Parallel()
t.Run("GroupACL", func(t *testing.T) {
t.Parallel()
expression, err := Compile(partialQueries(t,
`"read" in input.object.acl_group_list["4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75"]`,
))
require.NoError(t, err, "compile")
result := expression.Eval(Object{
Owner: "not-me",
OrgID: "random",
Type: "workspace",
ACLUserList: map[string][]Action{},
ACLGroupList: map[string][]Action{
"4d30d4a8-b87d-45ac-b0d4-51b2e68e7e75": {"read"},
},
})
require.True(t, result, "eval")
})
}
func partialQueries(t *testing.T, queries ...string) *PartialAuthorizer {
opts := ast.ParserOptions{
AllFutureKeywords: true,
}
astQueries := make([]ast.Body, 0, len(queries))
for _, q := range queries {
astQueries = append(astQueries, ast.MustParseBodyWithOpts(q, opts))
}
prepareQueries := make([]rego.PreparedEvalQuery, 0, len(queries))
for _, q := range astQueries {
var prepped rego.PreparedEvalQuery
var err error
if q.String() == "" {
prepped, err = rego.New(
rego.Query("true"),
).PrepareForEval(context.Background())
} else {
prepped, err = rego.New(
rego.ParsedQuery(q),
).PrepareForEval(context.Background())
}
require.NoError(t, err, "prepare query")
prepareQueries = append(prepareQueries, prepped)
}
return &PartialAuthorizer{
partialQueries: &rego.PartialQueries{
Queries: astQueries,
Support: []*ast.Module{},
},
preparedQueries: prepareQueries,
input: nil,
alwaysTrue: false,
}
}

View File

@ -19,8 +19,8 @@ var builtinScopes map[Scope]Role = map[Scope]Role{
ScopeAll: {
Name: fmt.Sprintf("Scope_%s", ScopeAll),
DisplayName: "All operations",
Site: permissions(map[Object][]Action{
ResourceWildcard: {WildcardSymbol},
Site: permissions(map[string][]Action{
ResourceWildcard.Type: {WildcardSymbol},
}),
Org: map[string][]Permission{},
User: []Permission{},
@ -29,8 +29,8 @@ var builtinScopes map[Scope]Role = map[Scope]Role{
ScopeApplicationConnect: {
Name: fmt.Sprintf("Scope_%s", ScopeApplicationConnect),
DisplayName: "Ability to connect to applications",
Site: permissions(map[Object][]Action{
ResourceWorkspaceApplicationConnect: {ActionCreate},
Site: permissions(map[string][]Action{
ResourceWorkspaceApplicationConnect.Type: {ActionCreate},
}),
Org: map[string][]Permission{},
User: []Permission{},

View File

@ -61,11 +61,6 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) {
return
}
if !api.Authorize(r, rbac.ActionRead, template) {
httpapi.ResourceNotFound(rw)
return
}
count := uint32(0)
if len(workspaceCounts) > 0 {
count = uint32(workspaceCounts[0].Count)
@ -248,9 +243,9 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
var dbTemplate database.Template
var template codersdk.Template
err = api.Database.InTx(func(db database.Store) error {
err = api.Database.InTx(func(tx database.Store) error {
now := database.Now()
dbTemplate, err = db.InsertTemplate(ctx, database.InsertTemplateParams{
dbTemplate, err = tx.InsertTemplate(ctx, database.InsertTemplateParams{
ID: uuid.New(),
CreatedAt: now,
UpdatedAt: now,
@ -269,7 +264,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
templateAudit.New = dbTemplate
err = db.UpdateTemplateVersionByID(ctx, database.UpdateTemplateVersionByIDParams{
err = tx.UpdateTemplateVersionByID(ctx, database.UpdateTemplateVersionByIDParams{
ID: templateVersion.ID,
TemplateID: uuid.NullUUID{
UUID: dbTemplate.ID,
@ -288,7 +283,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
templateVersionAudit.New = newTemplateVersion
for _, parameterValue := range createTemplate.ParameterValues {
_, err = db.InsertParameterValue(ctx, database.InsertParameterValueParams{
_, err = tx.InsertParameterValue(ctx, database.InsertParameterValueParams{
ID: uuid.New(),
Name: parameterValue.Name,
CreatedAt: database.Now(),
@ -304,7 +299,14 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
}
}
createdByNameMap, err := getCreatedByNamesByTemplateIDs(ctx, db, []database.Template{dbTemplate})
err = tx.UpdateTemplateGroupACLByID(ctx, dbTemplate.ID, database.TemplateACL{
dbTemplate.OrganizationID.String(): []rbac.Action{rbac.ActionRead},
})
if err != nil {
return xerrors.Errorf("update template group acl: %w", err)
}
createdByNameMap, err := getCreatedByNamesByTemplateIDs(ctx, tx, []database.Template{dbTemplate})
if err != nil {
return xerrors.Errorf("get creator name: %w", err)
}
@ -472,13 +474,7 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
validErrs = append(validErrs, codersdk.ValidationError{Field: "min_autostart_interval_ms", Detail: "Must be a positive integer."})
}
if req.MaxTTLMillis > maxTTLDefault.Milliseconds() {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid create template request.",
Validations: []codersdk.ValidationError{
{Field: "max_ttl_ms", Detail: "Cannot be greater than " + maxTTLDefault.String()},
},
})
return
validErrs = append(validErrs, codersdk.ValidationError{Field: "max_ttl_ms", Detail: "Cannot be greater than " + maxTTLDefault.String()})
}
if len(validErrs) > 0 {
@ -491,9 +487,9 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
count := uint32(0)
var updated database.Template
err := api.Database.InTx(func(s database.Store) error {
err := api.Database.InTx(func(tx database.Store) error {
// Fetch workspace counts
workspaceCounts, err := s.GetWorkspaceOwnerCountsByTemplateIDs(ctx, []uuid.UUID{template.ID})
workspaceCounts, err := tx.GetWorkspaceOwnerCountsByTemplateIDs(ctx, []uuid.UUID{template.ID})
if xerrors.Is(err, sql.ErrNoRows) {
err = nil
}
@ -530,7 +526,7 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
minAutostartInterval = time.Duration(template.MinAutostartInterval)
}
updated, err = s.UpdateTemplateMetaByID(ctx, database.UpdateTemplateMetaByIDParams{
updated, err = tx.UpdateTemplateMetaByID(ctx, database.UpdateTemplateMetaByIDParams{
ID: template.ID,
UpdatedAt: database.Now(),
Name: name,
@ -597,13 +593,13 @@ type autoImportTemplateOpts struct {
func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateOpts) (database.Template, error) {
var template database.Template
err := api.Database.InTx(func(s database.Store) error {
err := api.Database.InTx(func(tx database.Store) error {
// Insert the archive into the files table.
var (
hash = sha256.Sum256(opts.archive)
now = database.Now()
)
file, err := s.InsertFile(ctx, database.InsertFileParams{
file, err := tx.InsertFile(ctx, database.InsertFileParams{
Hash: hex.EncodeToString(hash[:]),
CreatedAt: now,
CreatedBy: opts.userID,
@ -618,7 +614,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
// Insert parameters
for key, value := range opts.params {
_, err = s.InsertParameterValue(ctx, database.InsertParameterValueParams{
_, err = tx.InsertParameterValue(ctx, database.InsertParameterValueParams{
ID: uuid.New(),
Name: key,
CreatedAt: now,
@ -635,7 +631,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
}
// Create provisioner job
job, err := s.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
job, err := tx.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: jobID,
CreatedAt: now,
UpdatedAt: now,
@ -652,7 +648,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
}
// Create template version
templateVersion, err := s.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
templateVersion, err := tx.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: uuid.New(),
TemplateID: uuid.NullUUID{
UUID: uuid.Nil,
@ -674,7 +670,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
}
// Create template
template, err = s.InsertTemplate(ctx, database.InsertTemplateParams{
template, err = tx.InsertTemplate(ctx, database.InsertTemplateParams{
ID: uuid.New(),
CreatedAt: now,
UpdatedAt: now,
@ -692,7 +688,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
}
// Update template version with template ID
err = s.UpdateTemplateVersionByID(ctx, database.UpdateTemplateVersionByIDParams{
err = tx.UpdateTemplateVersionByID(ctx, database.UpdateTemplateVersionByIDParams{
ID: templateVersion.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
@ -705,7 +701,7 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
// Insert parameters at the template scope
for key, value := range opts.params {
_, err = s.InsertParameterValue(ctx, database.InsertParameterValueParams{
_, err = tx.InsertParameterValue(ctx, database.InsertParameterValueParams{
ID: uuid.New(),
Name: key,
CreatedAt: now,
@ -721,6 +717,13 @@ func (api *API) autoImportTemplate(ctx context.Context, opts autoImportTemplateO
}
}
err = tx.UpdateTemplateGroupACLByID(ctx, template.ID, database.TemplateACL{
opts.orgID.String(): []rbac.Action{rbac.ActionRead},
})
if err != nil {
return xerrors.Errorf("update template group acl: %w", err)
}
return nil
})

View File

@ -24,8 +24,12 @@ import (
func (api *API) templateVersion(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -53,8 +57,11 @@ func (api *API) templateVersion(rw http.ResponseWriter, r *http.Request) {
func (api *API) patchCancelTemplateVersion(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionUpdate, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionUpdate, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -105,8 +112,12 @@ func (api *API) patchCancelTemplateVersion(rw http.ResponseWriter, r *http.Reque
func (api *API) templateVersionSchema(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -153,8 +164,11 @@ func (api *API) templateVersionSchema(rw http.ResponseWriter, r *http.Request) {
func (api *API) templateVersionParameters(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -195,9 +209,12 @@ func (api *API) templateVersionParameters(rw http.ResponseWriter, r *http.Reques
func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
apiKey = httpmw.APIKey(r)
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -367,9 +384,11 @@ func (api *API) fetchTemplateVersionDryRunJob(rw http.ResponseWriter, r *http.Re
var (
ctx = r.Context()
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
jobID = chi.URLParam(r, "jobID")
)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return database.ProvisionerJob{}, false
}
@ -667,14 +686,10 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return
}
// Making a new template version is the same permission as creating a new template.
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(organization.ID)) {
httpapi.ResourceNotFound(rw)
return
}
var template database.Template
if req.TemplateID != uuid.Nil {
_, err := api.Database.GetTemplateByID(ctx, req.TemplateID)
var err error
template, err = api.Database.GetTemplateByID(ctx, req.TemplateID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Template does not exist.",
@ -690,6 +705,17 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}
if template.ID != uuid.Nil {
if !api.Authorize(r, rbac.ActionCreate, template) {
httpapi.ResourceNotFound(rw)
return
}
} else if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceTemplate.InOrg(organization.ID)) {
// Making a new template version is the same permission as creating a new template.
httpapi.ResourceNotFound(rw)
return
}
file, err := api.Database.GetFileByHash(ctx, req.StorageSource)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@ -705,14 +731,16 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return
}
if !api.Authorize(r, rbac.ActionRead, file) {
httpapi.ResourceNotFound(rw)
return
}
// TODO(JonA): Readd this check once we update the unique constraint
// on files to be owner + hash.
// if !api.Authorize(r, rbac.ActionRead, file) {
// httpapi.ResourceNotFound(rw)
// return
// }
var templateVersion database.TemplateVersion
var provisionerJob database.ProvisionerJob
err = api.Database.InTx(func(db database.Store) error {
err = api.Database.InTx(func(tx database.Store) error {
jobID := uuid.New()
inherits := make([]uuid.UUID, 0)
for _, parameterValue := range req.ParameterValues {
@ -727,7 +755,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return xerrors.Errorf("cannot inherit parameters if template_id is not set")
}
inheritedParams, err := db.ParameterValues(ctx, database.ParameterValuesParams{
inheritedParams, err := tx.ParameterValues(ctx, database.ParameterValuesParams{
IDs: inherits,
})
if err != nil {
@ -736,7 +764,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
for _, copy := range inheritedParams {
// This is a bit inefficient, as we make a new db call for each
// param.
version, err := db.GetTemplateVersionByJobID(ctx, copy.ScopeID)
version, err := tx.GetTemplateVersionByJobID(ctx, copy.ScopeID)
if err != nil {
return xerrors.Errorf("fetch template version for param %q: %w", copy.Name, err)
}
@ -761,7 +789,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
continue
}
_, err = db.InsertParameterValue(ctx, database.InsertParameterValueParams{
_, err = tx.InsertParameterValue(ctx, database.InsertParameterValueParams{
ID: uuid.New(),
Name: parameterValue.Name,
CreatedAt: database.Now(),
@ -777,7 +805,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}
provisionerJob, err = db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
provisionerJob, err = tx.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: jobID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
@ -805,7 +833,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
req.Name = namesgenerator.GetRandomName(1)
}
templateVersion, err = db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
templateVersion, err = tx.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: uuid.New(),
TemplateID: templateID,
OrganizationID: organization.ID,
@ -851,8 +879,12 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
// return agents associated with any particular workspace.
func (api *API) templateVersionResources(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}
@ -874,8 +906,12 @@ func (api *API) templateVersionResources(rw http.ResponseWriter, r *http.Request
// Eg: Logs returned from 'terraform plan' when uploading a new terraform file.
func (api *API) templateVersionLogs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
templateVersion := httpmw.TemplateVersionParam(r)
if !api.Authorize(r, rbac.ActionRead, templateVersion) {
var (
templateVersion = httpmw.TemplateVersionParam(r)
template = httpmw.TemplateParam(r)
)
if !api.Authorize(r, rbac.ActionRead, templateVersion.RBACObject(template)) {
httpapi.ResourceNotFound(rw)
return
}

View File

@ -34,6 +34,22 @@ func TestTemplateVersion(t *testing.T) {
_, err := client.TemplateVersion(ctx, version.ID)
require.NoError(t, err)
})
t.Run("MemberCanRead", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
_ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
ctx, _ := testutil.Context(t)
client1, _ := coderdtest.CreateAnotherUserWithUser(t, client, user.OrganizationID)
_, err := client1.TemplateVersion(ctx, version.ID)
require.NoError(t, err)
})
}
func TestPostTemplateVersionsByOrganization(t *testing.T) {

View File

@ -1032,6 +1032,11 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
}
req.OrganizationID = organization.ID
orgRoles = append(orgRoles, rbac.RoleOrgAdmin(req.OrganizationID))
_, err = tx.InsertAllUsersGroup(ctx, organization.ID)
if err != nil {
return xerrors.Errorf("create %q group: %w", database.AllUsersGroup, err)
}
}
params := database.InsertUserParams{

View File

@ -360,7 +360,6 @@ func TestWorkspaceApplicationAuth(t *testing.T) {
ResourceType: "application_connect",
OwnerID: "me",
OrganizationID: firstUser.OrganizationID.String(),
ResourceID: uuid.NewString(),
},
Action: "create",
},