mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
chore: support multi-org group sync with runtime configuration (#14578)
- Implement multi-org group sync - Implement runtime configuration to change sync behavior - Legacy group sync migrated to new package
This commit is contained in:
@ -181,7 +181,6 @@ type Options struct {
|
||||
NetworkTelemetryBatchFrequency time.Duration
|
||||
NetworkTelemetryBatchMaxSize int
|
||||
SwaggerEndpoint bool
|
||||
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
|
||||
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
||||
@ -276,13 +275,6 @@ func New(options *Options) *API {
|
||||
if options.Entitlements == nil {
|
||||
options.Entitlements = entitlements.New()
|
||||
}
|
||||
if options.IDPSync == nil {
|
||||
options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{
|
||||
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
|
||||
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
|
||||
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
|
||||
})
|
||||
}
|
||||
if options.NewTicker == nil {
|
||||
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
|
||||
ticker := time.NewTicker(duration)
|
||||
@ -318,6 +310,10 @@ func New(options *Options) *API {
|
||||
options.AccessControlStore,
|
||||
)
|
||||
|
||||
if options.IDPSync == nil {
|
||||
options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues))
|
||||
}
|
||||
|
||||
experiments := ReadExperiments(
|
||||
options.Logger, options.DeploymentValues.Experiments.Value(),
|
||||
)
|
||||
@ -377,16 +373,6 @@ func New(options *Options) *API {
|
||||
if options.TracerProvider == nil {
|
||||
options.TracerProvider = trace.NewNoopTracerProvider()
|
||||
}
|
||||
if options.SetUserGroups == nil {
|
||||
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
|
||||
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
|
||||
slog.F("user_id", userID),
|
||||
slog.F("groups", orgGroupNames),
|
||||
slog.F("create_missing_groups", createMissingGroups),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if options.SetUserSiteRoles == nil {
|
||||
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
|
||||
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
|
||||
|
25
coderd/coderdtest/uuids.go
Normal file
25
coderd/coderdtest/uuids.go
Normal file
@ -0,0 +1,25 @@
|
||||
package coderdtest
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// DeterministicUUIDGenerator allows "naming" uuids for unit tests.
|
||||
// An example of where this is useful, is when a tabled test references
|
||||
// a UUID that is not yet known. An alternative to this would be to
|
||||
// hard code some UUID strings, but these strings are not human friendly.
|
||||
type DeterministicUUIDGenerator struct {
|
||||
Named map[string]uuid.UUID
|
||||
}
|
||||
|
||||
func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator {
|
||||
return &DeterministicUUIDGenerator{
|
||||
Named: make(map[string]uuid.UUID),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID {
|
||||
if v, ok := d.Named[name]; ok {
|
||||
return v
|
||||
}
|
||||
d.Named[name] = uuid.New()
|
||||
return d.Named[name]
|
||||
}
|
17
coderd/coderdtest/uuids_test.go
Normal file
17
coderd/coderdtest/uuids_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package coderdtest_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
)
|
||||
|
||||
func TestDeterministicUUIDGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||
require.Equal(t, ids.ID("g1"), ids.ID("g1"))
|
||||
require.NotEqual(t, ids.ID("g1"), ids.ID("g2"))
|
||||
}
|
@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams)
|
||||
return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||
// This is used by OIDC sync. So only used by a system user.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.InsertUserGroupsByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||
// This will add the user to all named groups. This counts as updating a group.
|
||||
// NOTE: instead of checking if the user has permission to update each group, we instead
|
||||
@ -3100,6 +3108,14 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.RemoveUserFromAllGroups(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||
// This is a system function to clear user groups in group sync.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.RemoveUserFromGroups(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
|
@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() {
|
||||
GroupNames: slice.New(g1.Name, g2.Name),
|
||||
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
u1 := dbgen.User(s.T(), db, database.User{})
|
||||
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
|
||||
check.Args(database.InsertUserGroupsByIDParams{
|
||||
UserID: u1.ID,
|
||||
GroupIds: slice.New(g1.ID, g2.ID),
|
||||
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
|
||||
}))
|
||||
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
u1 := dbgen.User(s.T(), db, database.User{})
|
||||
@ -397,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() {
|
||||
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
|
||||
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) {
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
u1 := dbgen.User(s.T(), db, database.User{})
|
||||
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
|
||||
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
|
||||
check.Args(database.RemoveUserFromGroupsParams{
|
||||
UserID: u1.ID,
|
||||
GroupIds: []uuid.UUID{g1.ID, g2.ID},
|
||||
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
|
||||
}))
|
||||
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
|
||||
g := dbgen.Group(s.T(), db, database.Group{})
|
||||
check.Args(database.UpdateGroupByIDParams{
|
||||
|
@ -2695,18 +2695,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
groupIDs := make(map[uuid.UUID]struct{})
|
||||
userGroupIDs := make(map[uuid.UUID]struct{})
|
||||
if arg.HasMemberID != uuid.Nil {
|
||||
for _, member := range q.groupMembers {
|
||||
if member.UserID == arg.HasMemberID {
|
||||
groupIDs[member.GroupID] = struct{}{}
|
||||
userGroupIDs[member.GroupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the everyone group
|
||||
for _, orgMember := range q.organizationMembers {
|
||||
if orgMember.UserID == arg.HasMemberID {
|
||||
groupIDs[orgMember.OrganizationID] = struct{}{}
|
||||
userGroupIDs[orgMember.OrganizationID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2718,11 +2718,15 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := groupIDs[group.ID]
|
||||
_, ok := userGroupIDs[group.ID]
|
||||
if arg.HasMemberID != uuid.Nil && !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
orgDetails, ok := orgDetailsCache[group.ID]
|
||||
if !ok {
|
||||
for _, org := range q.organizations {
|
||||
@ -7015,7 +7019,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
var groupIDs []uuid.UUID
|
||||
for _, group := range q.groups {
|
||||
for _, groupID := range arg.GroupIds {
|
||||
if group.ID == groupID {
|
||||
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
|
||||
UserID: arg.UserID,
|
||||
GroupID: groupID,
|
||||
})
|
||||
groupIDs = append(groupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
@ -7607,6 +7641,34 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
||||
removed := make([]uuid.UUID, 0)
|
||||
q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool {
|
||||
// Delete all group members that match the arguments.
|
||||
if groupMember.UserID != arg.UserID {
|
||||
// Not the right user, ignore.
|
||||
return false
|
||||
}
|
||||
|
||||
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
|
||||
return false
|
||||
}
|
||||
|
||||
removed = append(removed, groupMember.GroupID)
|
||||
return true
|
||||
})
|
||||
|
||||
return removed, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertUserGroupsByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.InsertUserGroupsByName(ctx, arg)
|
||||
@ -1943,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.RemoveUserFromGroups(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
|
@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1)
|
||||
}
|
||||
|
||||
// InsertUserGroupsByID mocks base method.
|
||||
func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID.
|
||||
func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// InsertUserGroupsByName mocks base method.
|
||||
func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@ -4103,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemoveUserFromGroups mocks base method.
|
||||
func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups.
|
||||
func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1)
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey mocks base method.
|
||||
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -369,6 +369,9 @@ type sqlcQuerier interface {
|
||||
InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error)
|
||||
InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error)
|
||||
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
||||
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||
// If there is a conflict, the user is already a member
|
||||
InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error)
|
||||
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
|
||||
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
|
||||
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
|
||||
@ -396,6 +399,7 @@ type sqlcQuerier interface {
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
||||
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
|
||||
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
|
@ -1446,6 +1446,56 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe
|
||||
return err
|
||||
}
|
||||
|
||||
const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many
|
||||
WITH groups AS (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
groups
|
||||
WHERE
|
||||
groups.id = ANY($2 :: uuid [])
|
||||
)
|
||||
INSERT INTO
|
||||
group_members (user_id, group_id)
|
||||
SELECT
|
||||
$1,
|
||||
groups.id
|
||||
FROM
|
||||
groups
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING group_id
|
||||
`
|
||||
|
||||
type InsertUserGroupsByIDParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
|
||||
}
|
||||
|
||||
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||
// If there is a conflict, the user is already a member
|
||||
func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var group_id uuid.UUID
|
||||
if err := rows.Scan(&group_id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, group_id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
|
||||
WITH groups AS (
|
||||
SELECT
|
||||
@ -1489,6 +1539,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU
|
||||
return err
|
||||
}
|
||||
|
||||
const removeUserFromGroups = `-- name: RemoveUserFromGroups :many
|
||||
DELETE FROM
|
||||
group_members
|
||||
WHERE
|
||||
user_id = $1 AND
|
||||
group_id = ANY($2 :: uuid [])
|
||||
RETURNING group_id
|
||||
`
|
||||
|
||||
type RemoveUserFromGroupsParams struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var group_id uuid.UUID
|
||||
if err := rows.Scan(&group_id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, group_id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const deleteGroupByID = `-- name: DeleteGroupByID :exec
|
||||
DELETE FROM
|
||||
groups
|
||||
@ -1592,11 +1679,16 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN
|
||||
groups.name = ANY($3)
|
||||
ELSE true
|
||||
END
|
||||
`
|
||||
|
||||
type GetGroupsParams struct {
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
|
||||
GroupNames []string `db:"group_names" json:"group_names"`
|
||||
}
|
||||
|
||||
type GetGroupsRow struct {
|
||||
@ -1606,7 +1698,7 @@ type GetGroupsRow struct {
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID)
|
||||
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -29,12 +29,41 @@ SELECT
|
||||
FROM
|
||||
groups;
|
||||
|
||||
-- InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||
-- name: InsertUserGroupsByID :many
|
||||
WITH groups AS (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
groups
|
||||
WHERE
|
||||
groups.id = ANY(@group_ids :: uuid [])
|
||||
)
|
||||
INSERT INTO
|
||||
group_members (user_id, group_id)
|
||||
SELECT
|
||||
@user_id,
|
||||
groups.id
|
||||
FROM
|
||||
groups
|
||||
-- If there is a conflict, the user is already a member
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING group_id;
|
||||
|
||||
-- name: RemoveUserFromAllGroups :exec
|
||||
DELETE FROM
|
||||
group_members
|
||||
WHERE
|
||||
user_id = @user_id;
|
||||
|
||||
-- name: RemoveUserFromGroups :many
|
||||
DELETE FROM
|
||||
group_members
|
||||
WHERE
|
||||
user_id = @user_id AND
|
||||
group_id = ANY(@group_ids :: uuid [])
|
||||
RETURNING group_id;
|
||||
|
||||
-- name: InsertGroupMember :exec
|
||||
INSERT INTO
|
||||
group_members (user_id, group_id)
|
||||
|
@ -52,6 +52,10 @@ WHERE
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN
|
||||
groups.name = ANY(@group_names)
|
||||
ELSE true
|
||||
END
|
||||
;
|
||||
|
||||
-- name: InsertGroup :one
|
||||
|
416
coderd/idpsync/group.go
Normal file
416
coderd/idpsync/group.go
Normal file
@ -0,0 +1,416 @@
|
||||
package idpsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
type GroupParams struct {
|
||||
// SyncEnabled if false will skip syncing the user's groups
|
||||
SyncEnabled bool
|
||||
MergedClaims jwt.MapClaims
|
||||
}
|
||||
|
||||
func (AGPLIDPSync) GroupSyncEnabled() bool {
|
||||
// AGPL does not support syncing groups.
|
||||
return false
|
||||
}
|
||||
|
||||
func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] {
|
||||
return s.Group
|
||||
}
|
||||
|
||||
func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) {
|
||||
return GroupParams{
|
||||
SyncEnabled: s.GroupSyncEnabled(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error {
|
||||
// Nothing happens if sync is not enabled
|
||||
if !params.SyncEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gocritic // all syncing is done as a system user
|
||||
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
// Only care about the default org for deployment settings if the
|
||||
// legacy deployment settings exist.
|
||||
defaultOrgID := uuid.Nil
|
||||
// Default organization is configured via legacy deployment values
|
||||
if s.DeploymentSyncSettings.Legacy.GroupField != "" {
|
||||
defaultOrganization, err := db.GetDefaultOrganization(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default organization: %w", err)
|
||||
}
|
||||
defaultOrgID = defaultOrganization.ID
|
||||
}
|
||||
|
||||
err := db.InTx(func(tx database.Store) error {
|
||||
userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
|
||||
HasMemberID: user.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get user groups: %w", err)
|
||||
}
|
||||
|
||||
// Figure out which organizations the user is a member of.
|
||||
// The "Everyone" group is always included, so we can infer organization
|
||||
// membership via the groups the user is in.
|
||||
userOrgs := make(map[uuid.UUID][]database.GetGroupsRow)
|
||||
for _, g := range userGroups {
|
||||
g := g
|
||||
userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g)
|
||||
}
|
||||
|
||||
// For each org, we need to fetch the sync settings
|
||||
// This loop also handles any legacy settings for the default
|
||||
// organization.
|
||||
orgSettings := make(map[uuid.UUID]GroupSyncSettings)
|
||||
for orgID := range userOrgs {
|
||||
orgResolver := s.Manager.OrganizationResolver(tx, orgID)
|
||||
settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver)
|
||||
if err != nil {
|
||||
if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) {
|
||||
return xerrors.Errorf("resolve group sync settings: %w", err)
|
||||
}
|
||||
// Default to not being configured
|
||||
settings = &GroupSyncSettings{}
|
||||
}
|
||||
|
||||
// Legacy deployment settings will override empty settings.
|
||||
if orgID == defaultOrgID && settings.Field == "" {
|
||||
settings = &GroupSyncSettings{
|
||||
Field: s.Legacy.GroupField,
|
||||
LegacyNameMapping: s.Legacy.GroupMapping,
|
||||
RegexFilter: s.Legacy.GroupFilter,
|
||||
AutoCreateMissing: s.Legacy.CreateMissingGroups,
|
||||
}
|
||||
}
|
||||
orgSettings[orgID] = *settings
|
||||
}
|
||||
|
||||
// groupIDsToAdd & groupIDsToRemove are the final group differences
|
||||
// needed to be applied to user. The loop below will iterate over all
|
||||
// organizations the user is in, and determine the diffs.
|
||||
// The diffs are applied as a batch sql query, rather than each
|
||||
// organization having to execute a query.
|
||||
groupIDsToAdd := make([]uuid.UUID, 0)
|
||||
groupIDsToRemove := make([]uuid.UUID, 0)
|
||||
// For each org, determine which groups the user should land in
|
||||
for orgID, settings := range orgSettings {
|
||||
if settings.Field == "" {
|
||||
// No group sync enabled for this org, so do nothing.
|
||||
// The user can remain in their groups for this org.
|
||||
continue
|
||||
}
|
||||
|
||||
// expectedGroups is the set of groups the IDP expects the
|
||||
// user to be a member of.
|
||||
expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims)
|
||||
if err != nil {
|
||||
s.Logger.Debug(ctx, "failed to parse claims for groups",
|
||||
slog.F("organization_field", s.GroupField),
|
||||
slog.F("organization_id", orgID),
|
||||
slog.Error(err),
|
||||
)
|
||||
// Unsure where to raise this error on the UI or database.
|
||||
// TODO: This error prevents group sync, but we have no way
|
||||
// to raise this to an org admin. Come up with a solution to
|
||||
// notify the admin and user of this issue.
|
||||
continue
|
||||
}
|
||||
// Everyone group is always implied, so include it.
|
||||
expectedGroups = append(expectedGroups, ExpectedGroup{
|
||||
OrganizationID: orgID,
|
||||
GroupID: &orgID,
|
||||
})
|
||||
|
||||
// Now we know what groups the user should be in for a given org,
|
||||
// determine if we have to do any group updates to sync the user's
|
||||
// state.
|
||||
existingGroups := userOrgs[orgID]
|
||||
existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup {
|
||||
return ExpectedGroup{
|
||||
OrganizationID: orgID,
|
||||
GroupID: &f.Group.ID,
|
||||
GroupName: &f.Group.Name,
|
||||
}
|
||||
})
|
||||
|
||||
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
|
||||
return a.Equal(b)
|
||||
})
|
||||
|
||||
for _, r := range remove {
|
||||
if r.GroupID == nil {
|
||||
// This should never happen. All group removals come from the
|
||||
// existing set, which come from the db. All groups from the
|
||||
// database have IDs. This code is purely defensive.
|
||||
detail := "user:" + user.Username
|
||||
if r.GroupName != nil {
|
||||
detail += fmt.Sprintf(" from group %s", *r.GroupName)
|
||||
}
|
||||
return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail)
|
||||
}
|
||||
groupIDsToRemove = append(groupIDsToRemove, *r.GroupID)
|
||||
}
|
||||
|
||||
// HandleMissingGroups will add the new groups to the org if
|
||||
// the settings specify. It will convert all group names into uuids
|
||||
// for easier assignment.
|
||||
// TODO: This code should be batched at the end of the for loop.
|
||||
// Optimizing this is being pushed because if AutoCreate is disabled,
|
||||
// this code will only add cost on the first login for each user.
|
||||
// AutoCreate is usually disabled for large deployments.
|
||||
// For small deployments, this is less of a problem.
|
||||
assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("handle missing groups: %w", err)
|
||||
}
|
||||
|
||||
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
|
||||
}
|
||||
|
||||
// ApplyGroupDifference will take the total adds and removes, and apply
|
||||
// them.
|
||||
err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("apply group difference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyGroupDifference will add and remove the user from the specified groups.
|
||||
func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
|
||||
if len(removeIDs) > 0 {
|
||||
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
|
||||
UserID: user.ID,
|
||||
GroupIds: removeIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
|
||||
}
|
||||
if len(removedGroupIDs) != len(removeIDs) {
|
||||
s.Logger.Debug(ctx, "user not removed from expected number of groups",
|
||||
slog.F("user_id", user.ID),
|
||||
slog.F("groups_removed_count", len(removedGroupIDs)),
|
||||
slog.F("expected_count", len(removeIDs)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(add) > 0 {
|
||||
add = slice.Unique(add)
|
||||
// Defensive programming to only insert uniques.
|
||||
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
|
||||
UserID: user.ID,
|
||||
GroupIds: add,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user into %d groups: %w", len(add), err)
|
||||
}
|
||||
if len(assignedGroupIDs) != len(add) {
|
||||
s.Logger.Debug(ctx, "user not assigned to expected number of groups",
|
||||
slog.F("user_id", user.ID),
|
||||
slog.F("groups_assigned_count", len(assignedGroupIDs)),
|
||||
slog.F("expected_count", len(add)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GroupSyncSettings struct {
|
||||
// Field selects the claim field to be used as the created user's
|
||||
// groups. If the group field is the empty string, then no group updates
|
||||
// will ever come from the OIDC provider.
|
||||
Field string `json:"field"`
|
||||
// Mapping maps from an OIDC group --> Coder group ID
|
||||
Mapping map[string][]uuid.UUID `json:"mapping"`
|
||||
// RegexFilter is a regular expression that filters the groups returned by
|
||||
// the OIDC provider. Any group not matched by this regex will be ignored.
|
||||
// If the group filter is nil, then no group filtering will occur.
|
||||
RegexFilter *regexp.Regexp `json:"regex_filter"`
|
||||
// AutoCreateMissing controls whether groups returned by the OIDC provider
|
||||
// are automatically created in Coder if they are missing.
|
||||
AutoCreateMissing bool `json:"auto_create_missing_groups"`
|
||||
// LegacyNameMapping is deprecated. It remaps an IDP group name to
|
||||
// a Coder group name. Since configuration is now done at runtime,
|
||||
// group IDs are used to account for group renames.
|
||||
// For legacy configurations, this config option has to remain.
|
||||
// Deprecated: Use Mapping instead.
|
||||
LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"`
|
||||
}
|
||||
|
||||
func (s *GroupSyncSettings) Set(v string) error {
|
||||
return json.Unmarshal([]byte(v), s)
|
||||
}
|
||||
|
||||
func (s *GroupSyncSettings) String() string {
|
||||
return runtimeconfig.JSONString(s)
|
||||
}
|
||||
|
||||
type ExpectedGroup struct {
|
||||
OrganizationID uuid.UUID
|
||||
GroupID *uuid.UUID
|
||||
GroupName *string
|
||||
}
|
||||
|
||||
// Equal compares two ExpectedGroups. The org id must be the same.
|
||||
// If the group ID is set, it will be compared and take priority, ignoring the
|
||||
// name value. So 2 groups with the same ID but different names will be
|
||||
// considered equal.
|
||||
func (a ExpectedGroup) Equal(b ExpectedGroup) bool {
|
||||
// Must match
|
||||
if a.OrganizationID != b.OrganizationID {
|
||||
return false
|
||||
}
|
||||
// Only the name or the name needs to be checked, priority is given to the ID.
|
||||
if a.GroupID != nil && b.GroupID != nil {
|
||||
return *a.GroupID == *b.GroupID
|
||||
}
|
||||
if a.GroupName != nil && b.GroupName != nil {
|
||||
return *a.GroupName == *b.GroupName
|
||||
}
|
||||
|
||||
// If everything is nil, it is equal. Although a bit pointless
|
||||
if a.GroupID == nil && b.GroupID == nil &&
|
||||
a.GroupName == nil && b.GroupName == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ParseClaims will take the merged claims from the IDP and return the groups
|
||||
// the user is expected to be a member of. The expected group can either be a
|
||||
// name or an ID.
|
||||
// It is unfortunate we cannot use exclusively names or exclusively IDs.
|
||||
// When configuring though, if a group is mapped from "A" -> "UUID 1234", and
|
||||
// the group "UUID 1234" is renamed, we want to maintain the mapping.
|
||||
// We have to keep names because group sync supports syncing groups by name if
|
||||
// the external IDP group name matches the Coder one.
|
||||
func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
|
||||
groupsRaw, ok := mergedClaims[s.Field]
|
||||
if !ok {
|
||||
return []ExpectedGroup{}, nil
|
||||
}
|
||||
|
||||
parsedGroups, err := ParseStringSliceClaim(groupsRaw)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err)
|
||||
}
|
||||
|
||||
groups := make([]ExpectedGroup, 0)
|
||||
for _, group := range parsedGroups {
|
||||
group := group
|
||||
|
||||
// Legacy group mappings happen before the regex filter.
|
||||
mappedGroupName, ok := s.LegacyNameMapping[group]
|
||||
if ok {
|
||||
group = mappedGroupName
|
||||
}
|
||||
|
||||
// Only allow through groups that pass the regex
|
||||
if s.RegexFilter != nil {
|
||||
if !s.RegexFilter.MatchString(group) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
mappedGroupIDs, ok := s.Mapping[group]
|
||||
if ok {
|
||||
for _, gid := range mappedGroupIDs {
|
||||
gid := gid
|
||||
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group})
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// HandleMissingGroups ensures all ExpectedGroups convert to uuids.
|
||||
// Groups can be referenced by name via legacy params or IDP group names.
|
||||
// These group names are converted to IDs for easier assignment.
|
||||
// Missing groups are created if AutoCreate is enabled.
|
||||
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
|
||||
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
|
||||
// All expected that are missing IDs means the group does not exist
|
||||
// in the database, or it is a legacy mapping, and we need to do a lookup.
|
||||
var missingGroups []string
|
||||
addIDs := make([]uuid.UUID, 0)
|
||||
|
||||
for _, expected := range add {
|
||||
if expected.GroupID == nil && expected.GroupName != nil {
|
||||
missingGroups = append(missingGroups, *expected.GroupName)
|
||||
} else if expected.GroupID != nil {
|
||||
// Keep the IDs to sync the groups.
|
||||
addIDs = append(addIDs, *expected.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
if s.AutoCreateMissing && len(missingGroups) > 0 {
|
||||
// Insert any missing groups. If the groups already exist, this is a noop.
|
||||
_, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
|
||||
OrganizationID: orgID,
|
||||
Source: database.GroupSourceOidc,
|
||||
GroupNames: missingGroups,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert missing groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch any missing groups by name. If they exist, their IDs will be
|
||||
// matched and returned.
|
||||
if len(missingGroups) > 0 {
|
||||
// Do name lookups for all groups that are missing IDs.
|
||||
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
|
||||
OrganizationID: orgID,
|
||||
HasMemberID: uuid.UUID{},
|
||||
GroupNames: missingGroups,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get groups by names: %w", err)
|
||||
}
|
||||
for _, g := range newGroups {
|
||||
addIDs = append(addIDs, g.Group.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return addIDs, nil
|
||||
}
|
||||
|
||||
func ConvertAllowList(allowList []string) map[string]struct{} {
|
||||
allowMap := make(map[string]struct{}, len(allowList))
|
||||
for _, group := range allowList {
|
||||
allowMap[group] = struct{}{}
|
||||
}
|
||||
return allowMap
|
||||
}
|
814
coderd/idpsync/group_test.go
Normal file
814
coderd/idpsync/group_test.go
Normal file
@ -0,0 +1,814 @@
|
||||
package idpsync_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestParseGroupClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyConfig", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
runtimeconfig.NewManager(),
|
||||
idpsync.DeploymentSyncSettings{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||
require.Nil(t, err)
|
||||
|
||||
require.False(t, params.SyncEnabled)
|
||||
})
|
||||
|
||||
// AllowList has no effect in AGPL
|
||||
t.Run("AllowList", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
runtimeconfig.NewManager(),
|
||||
idpsync.DeploymentSyncSettings{
|
||||
GroupField: "groups",
|
||||
GroupAllowList: map[string]struct{}{
|
||||
"foo": {},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||
require.Nil(t, err)
|
||||
require.False(t, params.SyncEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroupSyncTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Last checked, takes 30s with postgres on a fast machine.
|
||||
if dbtestutil.WillUsePostgres() {
|
||||
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
|
||||
}
|
||||
|
||||
userClaims := jwt.MapClaims{
|
||||
"groups": []string{
|
||||
"foo", "bar", "baz",
|
||||
"create-bar", "create-baz",
|
||||
"legacy-bar",
|
||||
},
|
||||
}
|
||||
|
||||
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||
testCases := []orgSetupDefinition{
|
||||
{
|
||||
Name: "SwitchGroups",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
Mapping: map[string][]uuid.UUID{
|
||||
"foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")},
|
||||
"bar": {ids.ID("sg-bar")},
|
||||
"baz": {ids.ID("sg-baz")},
|
||||
},
|
||||
},
|
||||
Groups: map[uuid.UUID]bool{
|
||||
uuid.New(): true,
|
||||
uuid.New(): true,
|
||||
// Extra groups
|
||||
ids.ID("sg-foo"): false,
|
||||
ids.ID("sg-foo-2"): false,
|
||||
ids.ID("sg-bar"): false,
|
||||
ids.ID("sg-baz"): false,
|
||||
},
|
||||
ExpectedGroups: []uuid.UUID{
|
||||
ids.ID("sg-foo"),
|
||||
ids.ID("sg-foo-2"),
|
||||
ids.ID("sg-bar"),
|
||||
ids.ID("sg-baz"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "StayInGroup",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
// Only match foo, so bar does not map
|
||||
RegexFilter: regexp.MustCompile("^foo$"),
|
||||
Mapping: map[string][]uuid.UUID{
|
||||
"foo": {ids.ID("gg-foo"), uuid.New()},
|
||||
"bar": {ids.ID("gg-bar")},
|
||||
"baz": {ids.ID("gg-baz")},
|
||||
},
|
||||
},
|
||||
Groups: map[uuid.UUID]bool{
|
||||
ids.ID("gg-foo"): true,
|
||||
ids.ID("gg-bar"): false,
|
||||
},
|
||||
ExpectedGroups: []uuid.UUID{
|
||||
ids.ID("gg-foo"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "UserJoinsGroups",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
Mapping: map[string][]uuid.UUID{
|
||||
"foo": {ids.ID("ng-foo"), uuid.New()},
|
||||
"bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")},
|
||||
"baz": {ids.ID("ng-baz")},
|
||||
},
|
||||
},
|
||||
Groups: map[uuid.UUID]bool{
|
||||
ids.ID("ng-foo"): false,
|
||||
ids.ID("ng-bar"): false,
|
||||
ids.ID("ng-bar-2"): false,
|
||||
ids.ID("ng-baz"): false,
|
||||
},
|
||||
ExpectedGroups: []uuid.UUID{
|
||||
ids.ID("ng-foo"),
|
||||
ids.ID("ng-bar"),
|
||||
ids.ID("ng-bar-2"),
|
||||
ids.ID("ng-baz"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "CreateGroups",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
RegexFilter: regexp.MustCompile("^create"),
|
||||
AutoCreateMissing: true,
|
||||
},
|
||||
Groups: map[uuid.UUID]bool{},
|
||||
ExpectedGroupNames: []string{
|
||||
"create-bar",
|
||||
"create-baz",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "GroupNamesNoMapping",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
RegexFilter: regexp.MustCompile(".*"),
|
||||
AutoCreateMissing: false,
|
||||
},
|
||||
GroupNames: map[string]bool{
|
||||
"foo": false,
|
||||
"bar": false,
|
||||
"goob": true,
|
||||
},
|
||||
ExpectedGroupNames: []string{
|
||||
"foo",
|
||||
"bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "NoUser",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
Mapping: map[string][]uuid.UUID{
|
||||
// Extra ID that does not map to a group
|
||||
"foo": {ids.ID("ow-foo"), uuid.New()},
|
||||
},
|
||||
RegexFilter: nil,
|
||||
AutoCreateMissing: false,
|
||||
},
|
||||
NotMember: true,
|
||||
Groups: map[uuid.UUID]bool{
|
||||
ids.ID("ow-foo"): false,
|
||||
ids.ID("ow-bar"): false,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "NoSettingsNoUser",
|
||||
Settings: nil,
|
||||
Groups: map[uuid.UUID]bool{},
|
||||
},
|
||||
{
|
||||
Name: "LegacyMapping",
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
RegexFilter: regexp.MustCompile("^legacy"),
|
||||
LegacyNameMapping: map[string]string{
|
||||
"create-bar": "legacy-bar",
|
||||
"foo": "legacy-foo",
|
||||
"bop": "legacy-bop",
|
||||
},
|
||||
AutoCreateMissing: true,
|
||||
},
|
||||
Groups: map[uuid.UUID]bool{
|
||||
ids.ID("lg-foo"): true,
|
||||
},
|
||||
GroupNames: map[string]bool{
|
||||
"legacy-foo": false,
|
||||
"extra": true,
|
||||
"legacy-bop": true,
|
||||
},
|
||||
ExpectedGroupNames: []string{
|
||||
"legacy-bar",
|
||||
"legacy-foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
manager := runtimeconfig.NewManager()
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
manager,
|
||||
idpsync.DeploymentSyncSettings{
|
||||
GroupField: "groups",
|
||||
},
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
orgID := uuid.New()
|
||||
SetupOrganization(t, s, db, user, orgID, tc)
|
||||
|
||||
// Do the group sync!
|
||||
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||
SyncEnabled: true,
|
||||
MergedClaims: userClaims,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tc.Assert(t, orgID, db, user)
|
||||
})
|
||||
}
|
||||
|
||||
// AllTogether runs the entire tabled test as a singular user and
|
||||
// deployment. This tests all organizations being synced together.
|
||||
// The reason we do them individually, is that it is much easier to
|
||||
// debug a single test case.
|
||||
t.Run("AllTogether", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
manager := runtimeconfig.NewManager()
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
manager,
|
||||
// Also sync the default org!
|
||||
idpsync.DeploymentSyncSettings{
|
||||
GroupField: "groups",
|
||||
Legacy: idpsync.DefaultOrgLegacySettings{
|
||||
GroupField: "groups",
|
||||
GroupMapping: map[string]string{
|
||||
"foo": "legacy-foo",
|
||||
"baz": "legacy-baz",
|
||||
},
|
||||
GroupFilter: regexp.MustCompile("^legacy"),
|
||||
CreateMissingGroups: true,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
var asserts []func(t *testing.T)
|
||||
// The default org is also going to do something
|
||||
def := orgSetupDefinition{
|
||||
Name: "DefaultOrg",
|
||||
GroupNames: map[string]bool{
|
||||
"legacy-foo": false,
|
||||
"legacy-baz": true,
|
||||
"random": true,
|
||||
},
|
||||
// No settings, because they come from the deployment values
|
||||
Settings: nil,
|
||||
ExpectedGroups: nil,
|
||||
ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"},
|
||||
}
|
||||
|
||||
//nolint:gocritic // testing
|
||||
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
||||
require.NoError(t, err)
|
||||
SetupOrganization(t, s, db, user, defOrg.ID, def)
|
||||
asserts = append(asserts, func(t *testing.T) {
|
||||
t.Run(def.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
def.Assert(t, defOrg.ID, db, user)
|
||||
})
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
orgID := uuid.New()
|
||||
SetupOrganization(t, s, db, user, orgID, tc)
|
||||
asserts = append(asserts, func(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.Assert(t, orgID, db, user)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
asserts = append(asserts, func(t *testing.T) {
|
||||
t.Helper()
|
||||
def.Assert(t, defOrg.ID, db, user)
|
||||
})
|
||||
|
||||
// Do the group sync!
|
||||
err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||
SyncEnabled: true,
|
||||
MergedClaims: userClaims,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, assert := range asserts {
|
||||
assert(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSyncDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if dbtestutil.WillUsePostgres() {
|
||||
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
|
||||
}
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
manager := runtimeconfig.NewManager()
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
manager,
|
||||
idpsync.DeploymentSyncSettings{},
|
||||
)
|
||||
|
||||
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
orgID := uuid.New()
|
||||
|
||||
def := orgSetupDefinition{
|
||||
Name: "SyncDisabled",
|
||||
Groups: map[uuid.UUID]bool{
|
||||
ids.ID("foo"): true,
|
||||
ids.ID("bar"): true,
|
||||
ids.ID("baz"): false,
|
||||
ids.ID("bop"): false,
|
||||
},
|
||||
Settings: &idpsync.GroupSyncSettings{
|
||||
Field: "groups",
|
||||
Mapping: map[string][]uuid.UUID{
|
||||
"foo": {ids.ID("foo")},
|
||||
"baz": {ids.ID("baz")},
|
||||
},
|
||||
},
|
||||
ExpectedGroups: []uuid.UUID{
|
||||
ids.ID("foo"),
|
||||
ids.ID("bar"),
|
||||
},
|
||||
}
|
||||
|
||||
SetupOrganization(t, s, db, user, orgID, def)
|
||||
|
||||
// Do the group sync!
|
||||
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||
SyncEnabled: false,
|
||||
MergedClaims: jwt.MapClaims{
|
||||
"groups": []string{"baz", "bop"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
def.Assert(t, orgID, db, user)
|
||||
}
|
||||
|
||||
// TestApplyGroupDifference is mainly testing the database functions
|
||||
func TestApplyGroupDifference(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||
testCase := []struct {
|
||||
Name string
|
||||
Before map[uuid.UUID]bool
|
||||
Add []uuid.UUID
|
||||
Remove []uuid.UUID
|
||||
Expect []uuid.UUID
|
||||
}{
|
||||
{
|
||||
Name: "Empty",
|
||||
},
|
||||
{
|
||||
Name: "AddFromNone",
|
||||
Before: map[uuid.UUID]bool{
|
||||
ids.ID("g1"): false,
|
||||
},
|
||||
Add: []uuid.UUID{
|
||||
ids.ID("g1"),
|
||||
},
|
||||
Expect: []uuid.UUID{
|
||||
ids.ID("g1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "AddSome",
|
||||
Before: map[uuid.UUID]bool{
|
||||
ids.ID("g1"): true,
|
||||
ids.ID("g2"): false,
|
||||
ids.ID("g3"): false,
|
||||
uuid.New(): false,
|
||||
},
|
||||
Add: []uuid.UUID{
|
||||
ids.ID("g2"),
|
||||
ids.ID("g3"),
|
||||
},
|
||||
Expect: []uuid.UUID{
|
||||
ids.ID("g1"),
|
||||
ids.ID("g2"),
|
||||
ids.ID("g3"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "RemoveAll",
|
||||
Before: map[uuid.UUID]bool{
|
||||
uuid.New(): false,
|
||||
ids.ID("g2"): true,
|
||||
ids.ID("g3"): true,
|
||||
},
|
||||
Remove: []uuid.UUID{
|
||||
ids.ID("g2"),
|
||||
ids.ID("g3"),
|
||||
},
|
||||
Expect: []uuid.UUID{},
|
||||
},
|
||||
{
|
||||
Name: "Mixed",
|
||||
Before: map[uuid.UUID]bool{
|
||||
// adds
|
||||
ids.ID("a1"): true,
|
||||
ids.ID("a2"): true,
|
||||
ids.ID("a3"): false,
|
||||
ids.ID("a4"): false,
|
||||
// removes
|
||||
ids.ID("r1"): true,
|
||||
ids.ID("r2"): true,
|
||||
ids.ID("r3"): false,
|
||||
ids.ID("r4"): false,
|
||||
// stable
|
||||
ids.ID("s1"): true,
|
||||
ids.ID("s2"): true,
|
||||
// noise
|
||||
uuid.New(): false,
|
||||
uuid.New(): false,
|
||||
},
|
||||
Add: []uuid.UUID{
|
||||
ids.ID("a1"), ids.ID("a2"),
|
||||
ids.ID("a3"), ids.ID("a4"),
|
||||
// Double up to try and confuse
|
||||
ids.ID("a1"),
|
||||
ids.ID("a4"),
|
||||
},
|
||||
Remove: []uuid.UUID{
|
||||
ids.ID("r1"), ids.ID("r2"),
|
||||
ids.ID("r3"), ids.ID("r4"),
|
||||
// Double up to try and confuse
|
||||
ids.ID("r1"),
|
||||
ids.ID("r4"),
|
||||
},
|
||||
Expect: []uuid.UUID{
|
||||
ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"),
|
||||
ids.ID("s1"), ids.ID("s2"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCase {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr := runtimeconfig.NewManager()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
//nolint:gocritic // testing
|
||||
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_, err := db.InsertAllUsersGroup(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
|
||||
for gid, in := range tc.Before {
|
||||
group := dbgen.Group(t, db, database.Group{
|
||||
ID: gid,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
if in {
|
||||
_ = dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t)))
|
||||
err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove)
|
||||
require.NoError(t, err)
|
||||
|
||||
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
||||
HasMemberID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// assert
|
||||
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
|
||||
return g.Group.ID
|
||||
})
|
||||
|
||||
// Add everyone group
|
||||
require.ElementsMatch(t, append(tc.Expect, org.ID), found)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpectedGroupEqual(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||
testCases := []struct {
|
||||
Name string
|
||||
A idpsync.ExpectedGroup
|
||||
B idpsync.ExpectedGroup
|
||||
Equal bool
|
||||
}{
|
||||
{
|
||||
Name: "Empty",
|
||||
A: idpsync.ExpectedGroup{},
|
||||
B: idpsync.ExpectedGroup{},
|
||||
Equal: true,
|
||||
},
|
||||
{
|
||||
Name: "DifferentOrgs",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: uuid.New(),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: uuid.New(),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
Equal: false,
|
||||
},
|
||||
{
|
||||
Name: "SameID",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
Equal: true,
|
||||
},
|
||||
{
|
||||
Name: "DifferentIDs",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(uuid.New()),
|
||||
GroupName: nil,
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(uuid.New()),
|
||||
GroupName: nil,
|
||||
},
|
||||
Equal: false,
|
||||
},
|
||||
{
|
||||
Name: "SameName",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: ptr.Ref("foo"),
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: ptr.Ref("foo"),
|
||||
},
|
||||
Equal: true,
|
||||
},
|
||||
{
|
||||
Name: "DifferentName",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: ptr.Ref("foo"),
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: ptr.Ref("bar"),
|
||||
},
|
||||
Equal: false,
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
// A bit strange, but valid as ID takes priority.
|
||||
// We assume 2 groups with the same ID are equal, even if
|
||||
// their names are different. Names are mutable, IDs are not,
|
||||
// so there is 0% chance they are different groups.
|
||||
Name: "DifferentIDSameName",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: ptr.Ref("foo"),
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: ptr.Ref("bar"),
|
||||
},
|
||||
Equal: true,
|
||||
},
|
||||
{
|
||||
Name: "MixedNils",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: ptr.Ref("bar"),
|
||||
},
|
||||
Equal: false,
|
||||
},
|
||||
{
|
||||
Name: "NoComparable",
|
||||
A: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: ptr.Ref(ids.ID("g1")),
|
||||
GroupName: nil,
|
||||
},
|
||||
B: idpsync.ExpectedGroup{
|
||||
OrganizationID: ids.ID("org"),
|
||||
GroupID: nil,
|
||||
GroupName: nil,
|
||||
},
|
||||
Equal: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, tc.Equal, tc.A.Equal(tc.B))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) {
|
||||
t.Helper()
|
||||
|
||||
// Account that the org might be the default organization
|
||||
org, err := db.GetOrganizationByID(context.Background(), orgID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
org = dbgen.Organization(t, db, database.Organization{
|
||||
ID: orgID,
|
||||
})
|
||||
}
|
||||
|
||||
_, err = db.InsertAllUsersGroup(context.Background(), org.ID)
|
||||
if !database.IsUniqueViolation(err) {
|
||||
require.NoError(t, err, "Everyone group for an org")
|
||||
}
|
||||
|
||||
manager := runtimeconfig.NewManager()
|
||||
orgResolver := manager.OrganizationResolver(db, org.ID)
|
||||
err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !def.NotMember {
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
}
|
||||
for groupID, in := range def.Groups {
|
||||
dbgen.Group(t, db, database.Group{
|
||||
ID: groupID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
if in {
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||
UserID: user.ID,
|
||||
GroupID: groupID,
|
||||
})
|
||||
}
|
||||
}
|
||||
for groupName, in := range def.GroupNames {
|
||||
group := dbgen.Group(t, db, database.Group{
|
||||
Name: groupName,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
if in {
|
||||
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type orgSetupDefinition struct {
|
||||
Name string
|
||||
// True if the user is a member of the group
|
||||
Groups map[uuid.UUID]bool
|
||||
GroupNames map[string]bool
|
||||
NotMember bool
|
||||
|
||||
Settings *idpsync.GroupSyncSettings
|
||||
ExpectedGroups []uuid.UUID
|
||||
ExpectedGroupNames []string
|
||||
}
|
||||
|
||||
func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
OrganizationID: orgID,
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if o.NotMember {
|
||||
require.Len(t, members, 0, "should not be a member")
|
||||
} else {
|
||||
require.Len(t, members, 1, "should be a member")
|
||||
}
|
||||
|
||||
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
||||
OrganizationID: orgID,
|
||||
HasMemberID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if o.ExpectedGroups == nil {
|
||||
o.ExpectedGroups = make([]uuid.UUID, 0)
|
||||
}
|
||||
if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 {
|
||||
t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive")
|
||||
}
|
||||
|
||||
// Everyone groups mess up our asserts
|
||||
userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool {
|
||||
return row.Group.ID == row.Group.OrganizationID
|
||||
})
|
||||
|
||||
if len(o.ExpectedGroupNames) > 0 {
|
||||
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string {
|
||||
return g.Group.Name
|
||||
})
|
||||
require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name")
|
||||
require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty")
|
||||
} else {
|
||||
// Check by ID, recommended
|
||||
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
|
||||
return g.Group.ID
|
||||
})
|
||||
require.ElementsMatch(t, o.ExpectedGroups, found, "user groups")
|
||||
require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty")
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ package idpsync
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/site"
|
||||
)
|
||||
@ -25,21 +27,34 @@ type IDPSync interface {
|
||||
OrganizationSyncEnabled() bool
|
||||
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
|
||||
// organization sync params for assigning users into organizations.
|
||||
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError)
|
||||
ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError)
|
||||
// SyncOrganizations assigns and removed users from organizations based on the
|
||||
// provided params.
|
||||
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
|
||||
|
||||
GroupSyncEnabled() bool
|
||||
// ParseGroupClaims takes claims from an OIDC provider, and returns the params
|
||||
// for group syncing. Most of the logic happens in SyncGroups.
|
||||
ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError)
|
||||
// SyncGroups assigns and removes users from groups based on the provided params.
|
||||
SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error
|
||||
// GroupSyncSettings is exposed for the API to implement CRUD operations
|
||||
// on the settings used by IDPSync. This entry is thread safe and can be
|
||||
// accessed concurrently. The settings are stored in the database.
|
||||
GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings]
|
||||
}
|
||||
|
||||
// AGPLIDPSync is the configuration for syncing user information from an external
|
||||
// IDP. All related code to syncing user information should be in this package.
|
||||
type AGPLIDPSync struct {
|
||||
Logger slog.Logger
|
||||
Logger slog.Logger
|
||||
Manager *runtimeconfig.Manager
|
||||
|
||||
SyncSettings
|
||||
}
|
||||
|
||||
type SyncSettings struct {
|
||||
// DeploymentSyncSettings are static and are sourced from the deployment config.
|
||||
type DeploymentSyncSettings struct {
|
||||
// OrganizationField selects the claim field to be used as the created user's
|
||||
// organizations. If the field is the empty string, then no organization updates
|
||||
// will ever come from the OIDC provider.
|
||||
@ -50,23 +65,62 @@ type SyncSettings struct {
|
||||
// placed into the default organization. This is mostly a hack to support
|
||||
// legacy deployments.
|
||||
OrganizationAssignDefault bool
|
||||
|
||||
// GroupField at the deployment level is used for deployment level group claim
|
||||
// settings.
|
||||
GroupField string
|
||||
// GroupAllowList (if set) will restrict authentication to only users who
|
||||
// have at least one group in this list.
|
||||
// A map representation is used for easier lookup.
|
||||
GroupAllowList map[string]struct{}
|
||||
// Legacy deployment settings that only apply to the default org.
|
||||
Legacy DefaultOrgLegacySettings
|
||||
}
|
||||
|
||||
type OrganizationParams struct {
|
||||
// SyncEnabled if false will skip syncing the user's organizations.
|
||||
SyncEnabled bool
|
||||
// IncludeDefault is primarily for single org deployments. It will ensure
|
||||
// a user is always inserted into the default org.
|
||||
IncludeDefault bool
|
||||
// Organizations is the list of organizations the user should be a member of
|
||||
// assuming syncing is turned on.
|
||||
Organizations []uuid.UUID
|
||||
type DefaultOrgLegacySettings struct {
|
||||
GroupField string
|
||||
GroupMapping map[string]string
|
||||
GroupFilter *regexp.Regexp
|
||||
CreateMissingGroups bool
|
||||
}
|
||||
|
||||
func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync {
|
||||
func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings {
|
||||
if dv == nil {
|
||||
panic("Developer error: DeploymentValues should not be nil")
|
||||
}
|
||||
return DeploymentSyncSettings{
|
||||
OrganizationField: dv.OIDC.OrganizationField.Value(),
|
||||
OrganizationMapping: dv.OIDC.OrganizationMapping.Value,
|
||||
OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(),
|
||||
|
||||
// TODO: Separate group field for allow list from default org.
|
||||
// Right now you cannot disable group sync from the default org and
|
||||
// configure an allow list.
|
||||
GroupField: dv.OIDC.GroupField.Value(),
|
||||
GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()),
|
||||
Legacy: DefaultOrgLegacySettings{
|
||||
GroupField: dv.OIDC.GroupField.Value(),
|
||||
GroupMapping: dv.OIDC.GroupMapping.Value,
|
||||
GroupFilter: dv.OIDC.GroupRegexFilter.Value(),
|
||||
CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type SyncSettings struct {
|
||||
DeploymentSyncSettings
|
||||
|
||||
Group runtimeconfig.RuntimeEntry[*GroupSyncSettings]
|
||||
}
|
||||
|
||||
func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync {
|
||||
return &AGPLIDPSync{
|
||||
Logger: logger.Named("idp-sync"),
|
||||
SyncSettings: settings,
|
||||
Logger: logger.Named("idp-sync"),
|
||||
Manager: manager,
|
||||
SyncSettings: SyncSettings{
|
||||
DeploymentSyncSettings: settings,
|
||||
Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,17 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
type OrganizationParams struct {
|
||||
// SyncEnabled if false will skip syncing the user's organizations.
|
||||
SyncEnabled bool
|
||||
// IncludeDefault is primarily for single org deployments. It will ensure
|
||||
// a user is always inserted into the default org.
|
||||
IncludeDefault bool
|
||||
// Organizations is the list of organizations the user should be a member of
|
||||
// assuming syncing is turned on.
|
||||
Organizations []uuid.UUID
|
||||
}
|
||||
|
||||
func (AGPLIDPSync) OrganizationSyncEnabled() bool {
|
||||
// AGPL does not support syncing organizations.
|
||||
return false
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/idpsync"
|
||||
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) {
|
||||
t.Run("SingleOrgDeployment", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
|
||||
OrganizationField: "",
|
||||
OrganizationMapping: nil,
|
||||
OrganizationAssignDefault: true,
|
||||
})
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
runtimeconfig.NewManager(),
|
||||
idpsync.DeploymentSyncSettings{
|
||||
OrganizationField: "",
|
||||
OrganizationMapping: nil,
|
||||
OrganizationAssignDefault: true,
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// AGPL has limited behavior
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
|
||||
OrganizationField: "orgs",
|
||||
OrganizationMapping: map[string][]uuid.UUID{
|
||||
"random": {uuid.New()},
|
||||
},
|
||||
OrganizationAssignDefault: false,
|
||||
})
|
||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||
runtimeconfig.NewManager(),
|
||||
idpsync.DeploymentSyncSettings{
|
||||
OrganizationField: "orgs",
|
||||
OrganizationMapping: map[string][]uuid.UUID{
|
||||
"random": {uuid.New()},
|
||||
},
|
||||
OrganizationAssignDefault: false,
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
|
@ -2,6 +2,7 @@ package runtimeconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) {
|
||||
|
||||
return e.n, nil
|
||||
}
|
||||
|
||||
func JSONString(v any) string {
|
||||
s, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "decode failed: " + err.Error()
|
||||
}
|
||||
return string(s)
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -20,7 +19,6 @@ import (
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@ -659,6 +657,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
||||
AvatarURL: ghUser.GetAvatarURL(),
|
||||
Name: normName,
|
||||
DebugContext: OauthDebugContext{},
|
||||
GroupSync: idpsync.GroupParams{
|
||||
SyncEnabled: false,
|
||||
},
|
||||
OrganizationSync: idpsync.OrganizationParams{
|
||||
SyncEnabled: false,
|
||||
IncludeDefault: true,
|
||||
@ -739,27 +740,6 @@ type OIDCConfig struct {
|
||||
// support the userinfo endpoint, or if the userinfo endpoint causes
|
||||
// undesirable behavior.
|
||||
IgnoreUserInfo bool
|
||||
|
||||
// TODO: Move all idp fields into the IDPSync struct
|
||||
// GroupField selects the claim field to be used as the created user's
|
||||
// groups. If the group field is the empty string, then no group updates
|
||||
// will ever come from the OIDC provider.
|
||||
GroupField string
|
||||
// CreateMissingGroups controls whether groups returned by the OIDC provider
|
||||
// are automatically created in Coder if they are missing.
|
||||
CreateMissingGroups bool
|
||||
// GroupFilter is a regular expression that filters the groups returned by
|
||||
// the OIDC provider. Any group not matched by this regex will be ignored.
|
||||
// If the group filter is nil, then no group filtering will occur.
|
||||
GroupFilter *regexp.Regexp
|
||||
// GroupAllowList is a list of groups that are allowed to log in.
|
||||
// If the list length is 0, then the allow list will not be applied and
|
||||
// this feature is disabled.
|
||||
GroupAllowList map[string]bool
|
||||
// GroupMapping controls how groups returned by the OIDC provider get mapped
|
||||
// to groups within Coder.
|
||||
// map[oidcGroupName]coderGroupName
|
||||
GroupMapping map[string]string
|
||||
// UserRoleField selects the claim field to be used as the created user's
|
||||
// roles. If the field is the empty string, then no role updates
|
||||
// will ever come from the OIDC provider.
|
||||
@ -1002,11 +982,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name))
|
||||
usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims)
|
||||
if groupErr != nil {
|
||||
groupErr.Write(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
roles, roleErr := api.oidcRoles(ctx, mergedClaims)
|
||||
if roleErr != nil {
|
||||
@ -1030,6 +1005,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims)
|
||||
if groupSyncErr != nil {
|
||||
groupSyncErr.Write(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If a new user is authenticating for the first time
|
||||
// the audit action is 'register', not 'login'
|
||||
if user.ID == uuid.Nil {
|
||||
@ -1037,23 +1018,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
params := (&oauthLoginParams{
|
||||
User: user,
|
||||
Link: link,
|
||||
State: state,
|
||||
LinkedID: oidcLinkedID(idToken),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
AllowSignups: api.OIDCConfig.AllowSignups,
|
||||
Email: email,
|
||||
Username: username,
|
||||
Name: name,
|
||||
AvatarURL: picture,
|
||||
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
|
||||
Roles: roles,
|
||||
UsingGroups: usingGroups,
|
||||
Groups: groups,
|
||||
OrganizationSync: orgSync,
|
||||
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
|
||||
GroupFilter: api.OIDCConfig.GroupFilter,
|
||||
User: user,
|
||||
Link: link,
|
||||
State: state,
|
||||
LinkedID: oidcLinkedID(idToken),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
AllowSignups: api.OIDCConfig.AllowSignups,
|
||||
Email: email,
|
||||
Username: username,
|
||||
Name: name,
|
||||
AvatarURL: picture,
|
||||
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
|
||||
Roles: roles,
|
||||
OrganizationSync: orgSync,
|
||||
GroupSync: groupSync,
|
||||
DebugContext: OauthDebugContext{
|
||||
IDTokenClaims: idtokenClaims,
|
||||
UserInfoClaims: userInfoClaims,
|
||||
@ -1089,79 +1067,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// oidcGroups returns the groups for the user from the OIDC claims.
|
||||
func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) {
|
||||
logger := api.Logger.Named(userAuthLoggerName)
|
||||
usingGroups := false
|
||||
var groups []string
|
||||
|
||||
// If the GroupField is the empty string, then groups from OIDC are not used.
|
||||
// This is so we can support manual group assignment.
|
||||
if api.OIDCConfig.GroupField != "" {
|
||||
// If the allow list is empty, then the user is allowed to log in.
|
||||
// Otherwise, they must belong to at least 1 group in the allow list.
|
||||
inAllowList := len(api.OIDCConfig.GroupAllowList) == 0
|
||||
|
||||
usingGroups = true
|
||||
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
|
||||
if ok {
|
||||
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
|
||||
if err != nil {
|
||||
api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims",
|
||||
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false, nil, &idpsync.HTTPError{
|
||||
Code: http.StatusBadRequest,
|
||||
Msg: "Failed to sync groups from OIDC claims",
|
||||
Detail: err.Error(),
|
||||
RenderStaticPage: false,
|
||||
}
|
||||
}
|
||||
|
||||
api.Logger.Debug(ctx, "groups returned in oidc claims",
|
||||
slog.F("len", len(parsedGroups)),
|
||||
slog.F("groups", parsedGroups),
|
||||
)
|
||||
|
||||
for _, group := range parsedGroups {
|
||||
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
|
||||
group = mappedGroup
|
||||
}
|
||||
if _, ok := api.OIDCConfig.GroupAllowList[group]; ok {
|
||||
inAllowList = true
|
||||
}
|
||||
groups = append(groups, group)
|
||||
}
|
||||
}
|
||||
|
||||
if !inAllowList {
|
||||
logger.Debug(ctx, "oidc group claim not in allow list, rejecting login",
|
||||
slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)),
|
||||
slog.F("user_group_count", len(groups)),
|
||||
)
|
||||
detail := "Ask an administrator to add one of your groups to the whitelist"
|
||||
if len(groups) == 0 {
|
||||
detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login."
|
||||
}
|
||||
return usingGroups, groups, &idpsync.HTTPError{
|
||||
Code: http.StatusForbidden,
|
||||
Msg: "Not a member of an allowed group",
|
||||
Detail: detail,
|
||||
RenderStaticPage: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This conditional is purely to warn the user they might have misconfigured their OIDC
|
||||
// configuration.
|
||||
if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists {
|
||||
logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings")
|
||||
}
|
||||
|
||||
return usingGroups, groups, nil
|
||||
}
|
||||
|
||||
// oidcRoles returns the roles for the user from the OIDC claims.
|
||||
// If the function returns false, then the caller should return early.
|
||||
// All writes to the response writer are handled by this function.
|
||||
@ -1276,14 +1181,7 @@ type oauthLoginParams struct {
|
||||
AvatarURL string
|
||||
// OrganizationSync has the organizations that the user will be assigned to.
|
||||
OrganizationSync idpsync.OrganizationParams
|
||||
// Is UsingGroups is true, then the user will be assigned
|
||||
// to the Groups provided.
|
||||
UsingGroups bool
|
||||
CreateMissingGroups bool
|
||||
// These are the group names from the IDP. Internally, they will map to
|
||||
// some organization groups.
|
||||
Groups []string
|
||||
GroupFilter *regexp.Regexp
|
||||
GroupSync idpsync.GroupParams
|
||||
// Is UsingRoles is true, then the user will be assigned
|
||||
// the roles provided.
|
||||
UsingRoles bool
|
||||
@ -1489,53 +1387,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
||||
return xerrors.Errorf("sync organizations: %w", err)
|
||||
}
|
||||
|
||||
// Ensure groups are correct.
|
||||
// This places all groups into the default organization.
|
||||
// To go multi-org, we need to add a mapping feature here to know which
|
||||
// groups go to which orgs.
|
||||
if params.UsingGroups {
|
||||
filtered := params.Groups
|
||||
if params.GroupFilter != nil {
|
||||
filtered = make([]string, 0, len(params.Groups))
|
||||
for _, group := range params.Groups {
|
||||
if params.GroupFilter.MatchString(group) {
|
||||
filtered = append(filtered, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocritic // No user present in the context.
|
||||
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil {
|
||||
// If there is no default org, then we can't assign groups.
|
||||
// By default, we assume all groups belong to the default org.
|
||||
return xerrors.Errorf("get default organization: %w", err)
|
||||
}
|
||||
|
||||
//nolint:gocritic // No user present in the context.
|
||||
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
|
||||
UserID: user.ID,
|
||||
OrganizationID: uuid.Nil,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get organization memberships: %w", err)
|
||||
}
|
||||
|
||||
// If the user is not in the default organization, then we can't assign groups.
|
||||
// A user cannot be in groups to an org they are not a member of.
|
||||
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
|
||||
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
|
||||
}) {
|
||||
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
|
||||
defaultOrganization.ID: filtered,
|
||||
}, params.CreateMissingGroups)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set user groups: %w", err)
|
||||
}
|
||||
// Group sync needs to occur after org sync, since a user can join an org,
|
||||
// then have their groups sync to said org.
|
||||
err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("sync groups: %w", err)
|
||||
}
|
||||
|
||||
// Ensure roles are correct.
|
||||
|
Reference in New Issue
Block a user