feat: expose Everyone group through UI (#9117)

- Allows setting quota allowances on the 'Everyone' group.
This commit is contained in:
Jon Ayers
2023-08-17 13:25:16 -05:00
committed by GitHub
parent 8910f05172
commit 2f6687a475
23 changed files with 458 additions and 80 deletions

View File

@ -916,11 +916,11 @@ func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGrou
return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg)
}
func (q *querier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) {
if _, err := q.GetGroupByID(ctx, groupID); err != nil { // AuthZ check
func (q *querier) GetGroupMembers(ctx context.Context, id uuid.UUID) ([]database.User, error) {
if _, err := q.GetGroupByID(ctx, id); err != nil { // AuthZ check
return nil, err
}
return q.db.GetGroupMembers(ctx, groupID)
return q.db.GetGroupMembers(ctx, id)
}
func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {

View File

@ -613,6 +613,44 @@ func uniqueSortedUUIDs(uuids []uuid.UUID) []uuid.UUID {
return unique
}
func (q *FakeQuerier) getOrganizationMember(orgID uuid.UUID) []database.OrganizationMember {
var members []database.OrganizationMember
for _, member := range q.organizationMembers {
if member.OrganizationID == orgID {
members = append(members, member)
}
}
return members
}
// getEveryoneGroupMembers fetches all the users in an organization.
func (q *FakeQuerier) getEveryoneGroupMembers(orgID uuid.UUID) []database.User {
var (
everyone []database.User
orgMembers = q.getOrganizationMember(orgID)
)
for _, member := range orgMembers {
user, err := q.GetUserByID(context.TODO(), member.UserID)
if err != nil {
return nil
}
everyone = append(everyone, user)
}
return everyone
}
// isEveryoneGroup returns true if the provided ID matches
// an organization ID.
func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
for _, org := range q.organizations {
if org.ID == id {
return true
}
}
return false
}
func (*FakeQuerier) AcquireLock(_ context.Context, _ int64) error {
return xerrors.New("AcquireLock must only be called within a transaction")
}
@ -1378,13 +1416,17 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr
return database.Group{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]database.User, error) {
func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]database.User, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
if q.isEveryoneGroup(id) {
return q.getEveryoneGroupMembers(id), nil
}
var members []database.GroupMember
for _, member := range q.groupMembers {
if member.GroupID == groupID {
if member.GroupID == id {
members = append(members, member)
}
}
@ -1403,14 +1445,13 @@ func (q *FakeQuerier) GetGroupMembers(_ context.Context, groupID uuid.UUID) ([]d
return users, nil
}
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationID uuid.UUID) ([]database.Group, error) {
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, id uuid.UUID) ([]database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var groups []database.Group
groups := make([]database.Group, 0, len(q.groups))
for _, group := range q.groups {
// Omit the allUsers group.
if group.OrganizationID == organizationID && group.ID != organizationID {
if group.OrganizationID == id {
groups = append(groups, group)
}
}
@ -1840,9 +1881,17 @@ func (q *FakeQuerier) GetQuotaAllowanceForUser(_ context.Context, userID uuid.UU
for _, group := range q.groups {
if group.ID == member.GroupID {
sum += int64(group.QuotaAllowance)
continue
}
}
}
// Grab the quota for the Everyone group.
for _, group := range q.groups {
if group.ID == group.OrganizationID {
sum += int64(group.QuotaAllowance)
break
}
}
return sum, nil
}
@ -3548,7 +3597,7 @@ func (q *FakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
func (q *FakeQuerier) InsertAllUsersGroup(ctx context.Context, orgID uuid.UUID) (database.Group, error) {
return q.InsertGroup(ctx, database.InsertGroupParams{
ID: orgID,
Name: database.AllUsersGroup,
Name: database.EveryoneGroup,
DisplayName: "",
OrganizationID: orgID,
})

View File

@ -84,7 +84,7 @@ func (g Group) Auditable(users []User) AuditableGroup {
}
}
const AllUsersGroup = "Everyone"
const EveryoneGroup = "Everyone"
func (s APIKeyScope) ToRBAC() rbac.ScopeName {
switch s {
@ -362,3 +362,7 @@ func ConvertWorkspaceRows(rows []GetWorkspacesRow) []Workspace {
return workspaces
}
func (g Group) IsEveryone() bool {
return g.ID == g.OrganizationID
}

View File

@ -72,6 +72,8 @@ type sqlcQuerier interface {
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
// If the group is a user made group, then we need to check the group_members table.
// If it is the "Everyone" group, then we need to check the organization_members table.
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error)

View File

@ -1069,18 +1069,29 @@ SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at, users.quiet_hours_schedule
FROM
users
JOIN
LEFT JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.user_id = users.id AND
group_members.group_id = $1
LEFT JOIN
organization_members
ON
organization_members.user_id = users.id AND
organization_members.organization_id = $1
WHERE
-- In either case, the group_id will only match an org or a group.
(group_members.group_id = $1
OR
organization_members.organization_id = $1)
AND
users.status = 'active'
AND
users.deleted = 'false'
`
// If the group is a user made group, then we need to check the group_members table.
// If it is the "Everyone" group, then we need to check the organization_members table.
func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID)
if err != nil {
@ -1244,8 +1255,6 @@ FROM
groups
WHERE
organization_id = $1
AND
id != $1
`
func (q *sqlQuerier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error) {
@ -3398,11 +3407,13 @@ const getQuotaAllowanceForUser = `-- name: GetQuotaAllowanceForUser :one
SELECT
coalesce(SUM(quota_allowance), 0)::BIGINT
FROM
group_members gm
JOIN groups g ON
groups g
LEFT JOIN group_members gm ON
g.id = gm.group_id
WHERE
user_id = $1
OR
g.id = g.organization_id
`
func (q *sqlQuerier) GetQuotaAllowanceForUser(ctx context.Context, userID uuid.UUID) (int64, error) {

View File

@ -3,12 +3,23 @@ SELECT
users.*
FROM
users
JOIN
-- If the group is a user made group, then we need to check the group_members table.
LEFT JOIN
group_members
ON
users.id = group_members.user_id
group_members.user_id = users.id AND
group_members.group_id = @group_id
-- If it is the "Everyone" group, then we need to check the organization_members table.
LEFT JOIN
organization_members
ON
organization_members.user_id = users.id AND
organization_members.organization_id = @group_id
WHERE
group_members.group_id = $1
-- In either case, the group_id will only match an org or a group.
(group_members.group_id = @group_id
OR
organization_members.organization_id = @group_id)
AND
users.status = 'active'
AND

View File

@ -26,9 +26,7 @@ SELECT
FROM
groups
WHERE
organization_id = $1
AND
id != $1;
organization_id = $1;
-- name: InsertGroup :one
INSERT INTO groups (

View File

@ -2,11 +2,13 @@
SELECT
coalesce(SUM(quota_allowance), 0)::BIGINT
FROM
group_members gm
JOIN groups g ON
groups g
LEFT JOIN group_members gm ON
g.id = gm.group_id
WHERE
user_id = $1;
user_id = $1
OR
g.id = g.organization_id;
-- name: GetQuotaConsumedForUser :one
WITH latest_builds AS (

49
coderd/database/tx.go Normal file
View File

@ -0,0 +1,49 @@
package database
import (
"database/sql"
"github.com/lib/pq"
"golang.org/x/xerrors"
)
const maxRetries = 5
// ReadModifyUpdate is a helper function to run a db transaction that reads some
// object(s), modifies some of the data, and writes the modified object(s) back
// to the database. It is run in a transaction at RepeatableRead isolation so
// that if another database client also modifies the data we are writing and
// commits, then the transaction is rolled back and restarted.
//
// This is needed because we typically read all object columns, modify some
// subset, and then write all columns. Consider an object with columns A, B and
// initial values A=1, B=1. Two database clients work simultaneously, with one
// client attempting to set A=2, and another attempting to set B=2. They both
// initially read A=1, B=1, and then one writes A=2, B=1, and the other writes
// A=1, B=2. With default PostgreSQL isolation of ReadCommitted, both of these
// transactions would succeed and we end up with either A=2, B=1 or A=1, B=2.
// One or other client gets their transaction wiped out even though the data
// they wanted to change didn't conflict.
//
// If we run at RepeatableRead isolation, then one or other transaction will
// fail. Let's say the transaction that sets A=2 succeeds. Then the first B=2
// transaction fails, but here we retry. The second attempt we read A=2, B=1,
// then write A=2, B=2 as desired, and this succeeds.
func ReadModifyUpdate(db Store, f func(tx Store) error,
) error {
var err error
for retries := 0; retries < maxRetries; retries++ {
err = db.InTx(f, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
})
var pqe *pq.Error
if xerrors.As(err, &pqe) {
if pqe.Code == "40001" {
// serialization error, retry
continue
}
}
return err
}
return xerrors.Errorf("too many errors; last error: %w", err)
}

View File

@ -0,0 +1,81 @@
package database_test
import (
"database/sql"
"testing"
"github.com/golang/mock/gomock"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbmock"
)
func TestReadModifyUpdate_OK(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().
InTx(gomock.Any(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}).
Times(1).
Return(nil)
err := database.ReadModifyUpdate(mDB, func(tx database.Store) error {
return nil
})
require.NoError(t, err)
}
func TestReadModifyUpdate_RetryOK(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
firstUpdate := mDB.EXPECT().
InTx(gomock.Any(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}).
Times(1).
Return(&pq.Error{Code: pq.ErrorCode("40001")})
mDB.EXPECT().
InTx(gomock.Any(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}).
After(firstUpdate).
Times(1).
Return(nil)
err := database.ReadModifyUpdate(mDB, func(tx database.Store) error {
return nil
})
require.NoError(t, err)
}
func TestReadModifyUpdate_HardError(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().
InTx(gomock.Any(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}).
Times(1).
Return(xerrors.New("a bad thing happened"))
err := database.ReadModifyUpdate(mDB, func(tx database.Store) error {
return nil
})
require.ErrorContains(t, err, "a bad thing happened")
}
func TestReadModifyUpdate_TooManyRetries(t *testing.T) {
t.Parallel()
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().
InTx(gomock.Any(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}).
Times(5).
Return(&pq.Error{Code: pq.ErrorCode("40001")})
err := database.ReadModifyUpdate(mDB, func(tx database.Store) error {
return nil
})
require.ErrorContains(t, err, "too many errors")
}