mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
feat: add template RBAC/groups (#4235)
This commit is contained in:
@ -12,23 +12,30 @@ import (
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
)
|
||||
|
||||
var errDuplicateKey = &pq.Error{
|
||||
Code: "23505",
|
||||
Message: "duplicate key value violates unique constraint",
|
||||
}
|
||||
|
||||
// New returns an in-memory fake of the database.
|
||||
func New() database.Store {
|
||||
return &fakeQuerier{
|
||||
mutex: &sync.RWMutex{},
|
||||
data: &data{
|
||||
apiKeys: make([]database.APIKey, 0),
|
||||
agentStats: make([]database.AgentStat, 0),
|
||||
organizationMembers: make([]database.OrganizationMember, 0),
|
||||
organizations: make([]database.Organization, 0),
|
||||
users: make([]database.User, 0),
|
||||
|
||||
apiKeys: make([]database.APIKey, 0),
|
||||
agentStats: make([]database.AgentStat, 0),
|
||||
organizationMembers: make([]database.OrganizationMember, 0),
|
||||
organizations: make([]database.Organization, 0),
|
||||
users: make([]database.User, 0),
|
||||
groups: make([]database.Group, 0),
|
||||
groupMembers: make([]database.GroupMember, 0),
|
||||
auditLogs: make([]database.AuditLog, 0),
|
||||
files: make([]database.File, 0),
|
||||
gitSSHKey: make([]database.GitSSHKey, 0),
|
||||
@ -84,6 +91,8 @@ type data struct {
|
||||
auditLogs []database.AuditLog
|
||||
files []database.File
|
||||
gitSSHKey []database.GitSSHKey
|
||||
groups []database.Group
|
||||
groupMembers []database.GroupMember
|
||||
parameterSchemas []database.ParameterSchema
|
||||
parameterValues []database.ParameterValue
|
||||
provisionerDaemons []database.ProvisionerDaemon
|
||||
@ -518,6 +527,13 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
|
||||
}
|
||||
}
|
||||
|
||||
var groups []string
|
||||
for _, member := range q.groupMembers {
|
||||
if member.UserID == userID {
|
||||
groups = append(groups, member.GroupID.String())
|
||||
}
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return database.GetAuthorizationUserRolesRow{}, sql.ErrNoRows
|
||||
}
|
||||
@ -527,6 +543,7 @@ func (q *fakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
|
||||
Username: user.Username,
|
||||
Status: user.Status,
|
||||
Roles: roles,
|
||||
Groups: groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -1269,6 +1286,116 @@ func (q *fakeQuerier) GetTemplates(_ context.Context) ([]database.Template, erro
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateTemplateUserACLByID(_ context.Context, id uuid.UUID, acl database.TemplateACL) error {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, t := range q.templates {
|
||||
if t.ID == id {
|
||||
t = t.SetUserACL(acl)
|
||||
q.templates[i] = t
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateTemplateGroupACLByID(_ context.Context, id uuid.UUID, acl database.TemplateACL) error {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, t := range q.templates {
|
||||
if t.ID == id {
|
||||
t = t.SetGroupACL(acl)
|
||||
q.templates[i] = t
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
acl := template.UserACL()
|
||||
|
||||
users := make([]database.TemplateUser, 0, len(acl))
|
||||
for k, v := range acl {
|
||||
user, err := q.GetUserByID(context.Background(), uuid.MustParse(k))
|
||||
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get user by ID: %w", err)
|
||||
}
|
||||
// We don't delete users from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
if user.Deleted || user.Status == database.UserStatusSuspended {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, database.TemplateUser{
|
||||
User: user,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var template database.Template
|
||||
for _, t := range q.templates {
|
||||
if t.ID == id {
|
||||
template = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if template.ID == uuid.Nil {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
acl := template.GroupACL()
|
||||
|
||||
groups := make([]database.TemplateGroup, 0, len(acl))
|
||||
for k, v := range acl {
|
||||
group, err := q.GetGroupByID(context.Background(), uuid.MustParse(k))
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return nil, xerrors.Errorf("get group by ID: %w", err)
|
||||
}
|
||||
// We don't delete groups from the map if they
|
||||
// get deleted so just skip.
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, database.TemplateGroup{
|
||||
Group: group,
|
||||
Actions: v,
|
||||
})
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@ -1749,6 +1876,10 @@ func (q *fakeQuerier) InsertTemplate(_ context.Context, arg database.InsertTempl
|
||||
MinAutostartInterval: arg.MinAutostartInterval,
|
||||
CreatedBy: arg.CreatedBy,
|
||||
}
|
||||
template = template.SetUserACL(database.TemplateACL{})
|
||||
template = template.SetGroupACL(database.TemplateACL{
|
||||
arg.OrganizationID.String(): []rbac.Action{rbac.ActionRead},
|
||||
})
|
||||
q.templates = append(q.templates, template)
|
||||
return template, nil
|
||||
}
|
||||
@ -2299,7 +2430,7 @@ func (q *fakeQuerier) UpdateWorkspace(_ context.Context, arg database.UpdateWork
|
||||
continue
|
||||
}
|
||||
if other.Name == arg.Name {
|
||||
return database.Workspace{}, &pq.Error{Code: "23505", Message: "duplicate key value violates unique constraint"}
|
||||
return database.Workspace{}, errDuplicateKey
|
||||
}
|
||||
}
|
||||
|
||||
@ -2437,6 +2568,52 @@ func (q *fakeQuerier) UpdateGitSSHKey(_ context.Context, arg database.UpdateGitS
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGroupMemberParams) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for _, member := range q.groupMembers {
|
||||
if member.GroupID == arg.GroupID &&
|
||||
member.UserID == arg.UserID {
|
||||
return errDuplicateKey
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gosimple
|
||||
q.groupMembers = append(q.groupMembers, database.GroupMember{
|
||||
GroupID: arg.GroupID,
|
||||
UserID: arg.UserID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) DeleteGroupMember(_ context.Context, userID uuid.UUID) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i, member := range q.groupMembers {
|
||||
if member.UserID == userID {
|
||||
q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i, group := range q.groups {
|
||||
if group.ID == arg.ID {
|
||||
group.Name = arg.Name
|
||||
q.groups[i] = group
|
||||
return group, nil
|
||||
}
|
||||
}
|
||||
return database.Group{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
@ -2714,3 +2891,137 @@ func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
|
||||
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetGroupByID(_ context.Context, id uuid.UUID) (database.Group, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, group := range q.groups {
|
||||
if group.ID == id {
|
||||
return group, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.Group{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGroupByOrgAndNameParams) (database.Group, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, group := range q.groups {
|
||||
if group.OrganizationID == arg.OrganizationID &&
|
||||
group.Name == arg.Name {
|
||||
return group, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.Group{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) {
|
||||
return q.InsertGroup(ctx, database.InsertGroupParams{
|
||||
ID: orgID,
|
||||
Name: database.AllUsersGroup,
|
||||
OrganizationID: orgID,
|
||||
})
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupParams) (database.Group, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, group := range q.groups {
|
||||
if group.OrganizationID.String() == arg.OrganizationID.String() &&
|
||||
group.Name == arg.Name {
|
||||
return database.Group{}, errDuplicateKey
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gosimple
|
||||
group := database.Group{
|
||||
ID: arg.ID,
|
||||
Name: arg.Name,
|
||||
OrganizationID: arg.OrganizationID,
|
||||
}
|
||||
|
||||
q.groups = append(q.groups, group)
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (*fakeQuerier) GetUserGroups(_ context.Context, _ uuid.UUID) ([]database.Group, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]database.User, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var members []database.GroupMember
|
||||
for _, member := range q.groupMembers {
|
||||
if member.GroupID == groupID {
|
||||
members = append(members, member)
|
||||
}
|
||||
}
|
||||
|
||||
users := make([]database.User, 0, len(members))
|
||||
|
||||
for _, member := range members {
|
||||
for _, user := range q.users {
|
||||
if user.ID == member.UserID && user.Status == database.UserStatusActive && !user.Deleted {
|
||||
users = append(users, user)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationID uuid.UUID) ([]database.Group, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var groups []database.Group
|
||||
for _, group := range q.groups {
|
||||
// Omit the allUsers group.
|
||||
if group.OrganizationID == organizationID && group.ID != organizationID {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAllOrganizationMembers(_ context.Context, organizationID uuid.UUID) ([]database.User, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var users []database.User
|
||||
for _, member := range q.organizationMembers {
|
||||
if member.OrganizationID == organizationID {
|
||||
for _, user := range q.users {
|
||||
if user.ID == member.UserID {
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
for i, group := range q.groups {
|
||||
if group.ID == id {
|
||||
q.groups = append(q.groups[:i], q.groups[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
Reference in New Issue
Block a user