chore: support multi-org group sync with runtime configuration (#14578)

- Implement multi-org group sync
- Implement runtime configuration to change sync behavior
- Legacy group sync migrated to new package
This commit is contained in:
Steven Masley
2024-09-11 13:43:50 -05:00
committed by GitHub
parent 7de576b596
commit 6a846cdbb8
27 changed files with 1920 additions and 341 deletions

View File

@ -181,7 +181,6 @@ type Options struct {
NetworkTelemetryBatchFrequency time.Duration
NetworkTelemetryBatchMaxSize int
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
@ -276,13 +275,6 @@ func New(options *Options) *API {
if options.Entitlements == nil {
options.Entitlements = entitlements.New()
}
if options.IDPSync == nil {
options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
})
}
if options.NewTicker == nil {
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
ticker := time.NewTicker(duration)
@ -318,6 +310,10 @@ func New(options *Options) *API {
options.AccessControlStore,
)
if options.IDPSync == nil {
options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues))
}
experiments := ReadExperiments(
options.Logger, options.DeploymentValues.Experiments.Value(),
)
@ -377,16 +373,6 @@ func New(options *Options) *API {
if options.TracerProvider == nil {
options.TracerProvider = trace.NewNoopTracerProvider()
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID),
slog.F("groups", orgGroupNames),
slog.F("create_missing_groups", createMissingGroups),
)
return nil
}
}
if options.SetUserSiteRoles == nil {
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",

View File

@ -0,0 +1,25 @@
package coderdtest
import "github.com/google/uuid"
// DeterministicUUIDGenerator allows "naming" uuids for unit tests.
// An example of where this is useful, is when a tabled test references
// a UUID that is not yet known. An alternative to this would be to
// hard code some UUID strings, but these strings are not human friendly.
type DeterministicUUIDGenerator struct {
Named map[string]uuid.UUID
}
func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator {
return &DeterministicUUIDGenerator{
Named: make(map[string]uuid.UUID),
}
}
func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID {
if v, ok := d.Named[name]; ok {
return v
}
d.Named[name] = uuid.New()
return d.Named[name]
}

View File

@ -0,0 +1,17 @@
package coderdtest_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
)
func TestDeterministicUUIDGenerator(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
require.Equal(t, ids.ID("g1"), ids.ID("g1"))
require.NotEqual(t, ids.ID("g1"), ids.ID("g2"))
}

View File

@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams)
return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg)
}
func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
// This is used by OIDC sync. So only used by a system user.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.InsertUserGroupsByID(ctx, arg)
}
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
// This will add the user to all named groups. This counts as updating a group.
// NOTE: instead of checking if the user has permission to update each group, we instead
@ -3100,6 +3108,14 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID)
return q.db.RemoveUserFromAllGroups(ctx, userID)
}
func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
// This is a system function to clear user groups in group sync.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err

View File

@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() {
GroupNames: slice.New(g1.Name, g2.Name),
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
}))
s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
check.Args(database.InsertUserGroupsByIDParams{
UserID: u1.ID,
GroupIds: slice.New(g1.ID, g2.ID),
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
}))
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
@ -397,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() {
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
}))
s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
check.Args(database.RemoveUserFromGroupsParams{
UserID: u1.ID,
GroupIds: []uuid.UUID{g1.ID, g2.ID},
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
}))
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{})
check.Args(database.UpdateGroupByIDParams{

View File

@ -2695,18 +2695,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
q.mutex.RLock()
defer q.mutex.RUnlock()
groupIDs := make(map[uuid.UUID]struct{})
userGroupIDs := make(map[uuid.UUID]struct{})
if arg.HasMemberID != uuid.Nil {
for _, member := range q.groupMembers {
if member.UserID == arg.HasMemberID {
groupIDs[member.GroupID] = struct{}{}
userGroupIDs[member.GroupID] = struct{}{}
}
}
// Handle the everyone group
for _, orgMember := range q.organizationMembers {
if orgMember.UserID == arg.HasMemberID {
groupIDs[orgMember.OrganizationID] = struct{}{}
userGroupIDs[orgMember.OrganizationID] = struct{}{}
}
}
}
@ -2718,11 +2718,15 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
continue
}
_, ok := groupIDs[group.ID]
_, ok := userGroupIDs[group.ID]
if arg.HasMemberID != uuid.Nil && !ok {
continue
}
if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) {
continue
}
orgDetails, ok := orgDetailsCache[group.ID]
if !ok {
for _, org := range q.organizations {
@ -7015,7 +7019,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
return user, nil
}
func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
var groupIDs []uuid.UUID
for _, group := range q.groups {
for _, groupID := range arg.GroupIds {
if group.ID == groupID {
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
UserID: arg.UserID,
GroupID: groupID,
})
groupIDs = append(groupIDs, group.ID)
}
}
}
return groupIDs, nil
}
func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
q.mutex.Lock()
defer q.mutex.Unlock()
@ -7607,6 +7641,34 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI
return nil
}
func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
removed := make([]uuid.UUID, 0)
q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool {
// Delete all group members that match the arguments.
if groupMember.UserID != arg.UserID {
// Not the right user, ignore.
return false
}
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
return false
}
removed = append(removed, groupMember.GroupID)
return true
})
return removed, nil
}
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar
return user, err
}
func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.InsertUserGroupsByID(ctx, arg)
m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
start := time.Now()
err := m.s.InsertUserGroupsByName(ctx, arg)
@ -1943,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U
return r0
}
func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.RemoveUserFromGroups(ctx, arg)
m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)

View File

@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1)
}
// InsertUserGroupsByID mocks base method.
func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID.
func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1)
}
// InsertUserGroupsByName mocks base method.
func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error {
m.ctrl.T.Helper()
@ -4103,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
}
// RemoveUserFromGroups mocks base method.
func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups.
func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()

View File

@ -369,6 +369,9 @@ type sqlcQuerier interface {
InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error)
InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
// If there is a conflict, the user is already a member
InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error)
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
@ -396,6 +399,7 @@ type sqlcQuerier interface {
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
//

View File

@ -1446,6 +1446,56 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe
return err
}
const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.id = ANY($2 :: uuid [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
$1,
groups.id
FROM
groups
ON CONFLICT DO NOTHING
RETURNING group_id
`
type InsertUserGroupsByIDParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
}
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
// If there is a conflict, the user is already a member
func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var group_id uuid.UUID
if err := rows.Scan(&group_id); err != nil {
return nil, err
}
items = append(items, group_id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
WITH groups AS (
SELECT
@ -1489,6 +1539,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU
return err
}
const removeUserFromGroups = `-- name: RemoveUserFromGroups :many
DELETE FROM
group_members
WHERE
user_id = $1 AND
group_id = ANY($2 :: uuid [])
RETURNING group_id
`
type RemoveUserFromGroupsParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
}
func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var group_id uuid.UUID
if err := rows.Scan(&group_id); err != nil {
return nil, err
}
items = append(items, group_id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM
groups
@ -1592,11 +1679,16 @@ WHERE
)
ELSE true
END
AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN
groups.name = ANY($3)
ELSE true
END
`
type GetGroupsParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
GroupNames []string `db:"group_names" json:"group_names"`
}
type GetGroupsRow struct {
@ -1606,7 +1698,7 @@ type GetGroupsRow struct {
}
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) {
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID)
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames))
if err != nil {
return nil, err
}

View File

@ -29,12 +29,41 @@ SELECT
FROM
groups;
-- InsertUserGroupsByID adds a user to all provided groups, if they exist.
-- name: InsertUserGroupsByID :many
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.id = ANY(@group_ids :: uuid [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
@user_id,
groups.id
FROM
groups
-- If there is a conflict, the user is already a member
ON CONFLICT DO NOTHING
RETURNING group_id;
-- name: RemoveUserFromAllGroups :exec
DELETE FROM
group_members
WHERE
user_id = @user_id;
-- name: RemoveUserFromGroups :many
DELETE FROM
group_members
WHERE
user_id = @user_id AND
group_id = ANY(@group_ids :: uuid [])
RETURNING group_id;
-- name: InsertGroupMember :exec
INSERT INTO
group_members (user_id, group_id)

View File

@ -52,6 +52,10 @@ WHERE
)
ELSE true
END
AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN
groups.name = ANY(@group_names)
ELSE true
END
;
-- name: InsertGroup :one

416
coderd/idpsync/group.go Normal file
View File

@ -0,0 +1,416 @@
package idpsync
import (
"context"
"encoding/json"
"fmt"
"regexp"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/util/slice"
)
type GroupParams struct {
// SyncEnabled if false will skip syncing the user's groups
SyncEnabled bool
MergedClaims jwt.MapClaims
}
func (AGPLIDPSync) GroupSyncEnabled() bool {
// AGPL does not support syncing groups.
return false
}
func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] {
return s.Group
}
func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) {
return GroupParams{
SyncEnabled: s.GroupSyncEnabled(),
}, nil
}
func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error {
// Nothing happens if sync is not enabled
if !params.SyncEnabled {
return nil
}
// nolint:gocritic // all syncing is done as a system user
ctx = dbauthz.AsSystemRestricted(ctx)
// Only care about the default org for deployment settings if the
// legacy deployment settings exist.
defaultOrgID := uuid.Nil
// Default organization is configured via legacy deployment values
if s.DeploymentSyncSettings.Legacy.GroupField != "" {
defaultOrganization, err := db.GetDefaultOrganization(ctx)
if err != nil {
return xerrors.Errorf("get default organization: %w", err)
}
defaultOrgID = defaultOrganization.ID
}
err := db.InTx(func(tx database.Store) error {
userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
HasMemberID: user.ID,
})
if err != nil {
return xerrors.Errorf("get user groups: %w", err)
}
// Figure out which organizations the user is a member of.
// The "Everyone" group is always included, so we can infer organization
// membership via the groups the user is in.
userOrgs := make(map[uuid.UUID][]database.GetGroupsRow)
for _, g := range userGroups {
g := g
userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g)
}
// For each org, we need to fetch the sync settings
// This loop also handles any legacy settings for the default
// organization.
orgSettings := make(map[uuid.UUID]GroupSyncSettings)
for orgID := range userOrgs {
orgResolver := s.Manager.OrganizationResolver(tx, orgID)
settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver)
if err != nil {
if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) {
return xerrors.Errorf("resolve group sync settings: %w", err)
}
// Default to not being configured
settings = &GroupSyncSettings{}
}
// Legacy deployment settings will override empty settings.
if orgID == defaultOrgID && settings.Field == "" {
settings = &GroupSyncSettings{
Field: s.Legacy.GroupField,
LegacyNameMapping: s.Legacy.GroupMapping,
RegexFilter: s.Legacy.GroupFilter,
AutoCreateMissing: s.Legacy.CreateMissingGroups,
}
}
orgSettings[orgID] = *settings
}
// groupIDsToAdd & groupIDsToRemove are the final group differences
// needed to be applied to user. The loop below will iterate over all
// organizations the user is in, and determine the diffs.
// The diffs are applied as a batch sql query, rather than each
// organization having to execute a query.
groupIDsToAdd := make([]uuid.UUID, 0)
groupIDsToRemove := make([]uuid.UUID, 0)
// For each org, determine which groups the user should land in
for orgID, settings := range orgSettings {
if settings.Field == "" {
// No group sync enabled for this org, so do nothing.
// The user can remain in their groups for this org.
continue
}
// expectedGroups is the set of groups the IDP expects the
// user to be a member of.
expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims)
if err != nil {
s.Logger.Debug(ctx, "failed to parse claims for groups",
slog.F("organization_field", s.GroupField),
slog.F("organization_id", orgID),
slog.Error(err),
)
// Unsure where to raise this error on the UI or database.
// TODO: This error prevents group sync, but we have no way
// to raise this to an org admin. Come up with a solution to
// notify the admin and user of this issue.
continue
}
// Everyone group is always implied, so include it.
expectedGroups = append(expectedGroups, ExpectedGroup{
OrganizationID: orgID,
GroupID: &orgID,
})
// Now we know what groups the user should be in for a given org,
// determine if we have to do any group updates to sync the user's
// state.
existingGroups := userOrgs[orgID]
existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup {
return ExpectedGroup{
OrganizationID: orgID,
GroupID: &f.Group.ID,
GroupName: &f.Group.Name,
}
})
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
return a.Equal(b)
})
for _, r := range remove {
if r.GroupID == nil {
// This should never happen. All group removals come from the
// existing set, which come from the db. All groups from the
// database have IDs. This code is purely defensive.
detail := "user:" + user.Username
if r.GroupName != nil {
detail += fmt.Sprintf(" from group %s", *r.GroupName)
}
return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail)
}
groupIDsToRemove = append(groupIDsToRemove, *r.GroupID)
}
// HandleMissingGroups will add the new groups to the org if
// the settings specify. It will convert all group names into uuids
// for easier assignment.
// TODO: This code should be batched at the end of the for loop.
// Optimizing this is being pushed because if AutoCreate is disabled,
// this code will only add cost on the first login for each user.
// AutoCreate is usually disabled for large deployments.
// For small deployments, this is less of a problem.
assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add)
if err != nil {
return xerrors.Errorf("handle missing groups: %w", err)
}
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
}
// ApplyGroupDifference will take the total adds and removes, and apply
// them.
err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
if err != nil {
return xerrors.Errorf("apply group difference: %w", err)
}
return nil
}, nil)
if err != nil {
return err
}
return nil
}
// ApplyGroupDifference will add and remove the user from the specified groups.
func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
if len(removeIDs) > 0 {
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
UserID: user.ID,
GroupIds: removeIDs,
})
if err != nil {
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
}
if len(removedGroupIDs) != len(removeIDs) {
s.Logger.Debug(ctx, "user not removed from expected number of groups",
slog.F("user_id", user.ID),
slog.F("groups_removed_count", len(removedGroupIDs)),
slog.F("expected_count", len(removeIDs)),
)
}
}
if len(add) > 0 {
add = slice.Unique(add)
// Defensive programming to only insert uniques.
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
UserID: user.ID,
GroupIds: add,
})
if err != nil {
return xerrors.Errorf("insert user into %d groups: %w", len(add), err)
}
if len(assignedGroupIDs) != len(add) {
s.Logger.Debug(ctx, "user not assigned to expected number of groups",
slog.F("user_id", user.ID),
slog.F("groups_assigned_count", len(assignedGroupIDs)),
slog.F("expected_count", len(add)),
)
}
}
return nil
}
type GroupSyncSettings struct {
// Field selects the claim field to be used as the created user's
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
Field string `json:"field"`
// Mapping maps from an OIDC group --> Coder group ID
Mapping map[string][]uuid.UUID `json:"mapping"`
// RegexFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
RegexFilter *regexp.Regexp `json:"regex_filter"`
// AutoCreateMissing controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
AutoCreateMissing bool `json:"auto_create_missing_groups"`
// LegacyNameMapping is deprecated. It remaps an IDP group name to
// a Coder group name. Since configuration is now done at runtime,
// group IDs are used to account for group renames.
// For legacy configurations, this config option has to remain.
// Deprecated: Use Mapping instead.
LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"`
}
func (s *GroupSyncSettings) Set(v string) error {
return json.Unmarshal([]byte(v), s)
}
func (s *GroupSyncSettings) String() string {
return runtimeconfig.JSONString(s)
}
type ExpectedGroup struct {
OrganizationID uuid.UUID
GroupID *uuid.UUID
GroupName *string
}
// Equal compares two ExpectedGroups. The org id must be the same.
// If the group ID is set, it will be compared and take priority, ignoring the
// name value. So 2 groups with the same ID but different names will be
// considered equal.
func (a ExpectedGroup) Equal(b ExpectedGroup) bool {
// Must match
if a.OrganizationID != b.OrganizationID {
return false
}
// Only the name or the name needs to be checked, priority is given to the ID.
if a.GroupID != nil && b.GroupID != nil {
return *a.GroupID == *b.GroupID
}
if a.GroupName != nil && b.GroupName != nil {
return *a.GroupName == *b.GroupName
}
// If everything is nil, it is equal. Although a bit pointless
if a.GroupID == nil && b.GroupID == nil &&
a.GroupName == nil && b.GroupName == nil {
return true
}
return false
}
// ParseClaims will take the merged claims from the IDP and return the groups
// the user is expected to be a member of. The expected group can either be a
// name or an ID.
// It is unfortunate we cannot use exclusively names or exclusively IDs.
// When configuring though, if a group is mapped from "A" -> "UUID 1234", and
// the group "UUID 1234" is renamed, we want to maintain the mapping.
// We have to keep names because group sync supports syncing groups by name if
// the external IDP group name matches the Coder one.
func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
groupsRaw, ok := mergedClaims[s.Field]
if !ok {
return []ExpectedGroup{}, nil
}
parsedGroups, err := ParseStringSliceClaim(groupsRaw)
if err != nil {
return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err)
}
groups := make([]ExpectedGroup, 0)
for _, group := range parsedGroups {
group := group
// Legacy group mappings happen before the regex filter.
mappedGroupName, ok := s.LegacyNameMapping[group]
if ok {
group = mappedGroupName
}
// Only allow through groups that pass the regex
if s.RegexFilter != nil {
if !s.RegexFilter.MatchString(group) {
continue
}
}
mappedGroupIDs, ok := s.Mapping[group]
if ok {
for _, gid := range mappedGroupIDs {
gid := gid
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid})
}
continue
}
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group})
}
return groups, nil
}
// HandleMissingGroups ensures all ExpectedGroups convert to uuids.
// Groups can be referenced by name via legacy params or IDP group names.
// These group names are converted to IDs for easier assignment.
// Missing groups are created if AutoCreate is enabled.
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
// All expected that are missing IDs means the group does not exist
// in the database, or it is a legacy mapping, and we need to do a lookup.
var missingGroups []string
addIDs := make([]uuid.UUID, 0)
for _, expected := range add {
if expected.GroupID == nil && expected.GroupName != nil {
missingGroups = append(missingGroups, *expected.GroupName)
} else if expected.GroupID != nil {
// Keep the IDs to sync the groups.
addIDs = append(addIDs, *expected.GroupID)
}
}
if s.AutoCreateMissing && len(missingGroups) > 0 {
// Insert any missing groups. If the groups already exist, this is a noop.
_, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
OrganizationID: orgID,
Source: database.GroupSourceOidc,
GroupNames: missingGroups,
})
if err != nil {
return nil, xerrors.Errorf("insert missing groups: %w", err)
}
}
// Fetch any missing groups by name. If they exist, their IDs will be
// matched and returned.
if len(missingGroups) > 0 {
// Do name lookups for all groups that are missing IDs.
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
OrganizationID: orgID,
HasMemberID: uuid.UUID{},
GroupNames: missingGroups,
})
if err != nil {
return nil, xerrors.Errorf("get groups by names: %w", err)
}
for _, g := range newGroups {
addIDs = append(addIDs, g.Group.ID)
}
}
return addIDs, nil
}
func ConvertAllowList(allowList []string) map[string]struct{} {
allowMap := make(map[string]struct{}, len(allowList))
for _, group := range allowList {
allowMap[group] = struct{}{}
}
return allowMap
}

View File

@ -0,0 +1,814 @@
package idpsync_test
import (
"context"
"database/sql"
"regexp"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/testutil"
)
func TestParseGroupClaims(t *testing.T) {
t.Parallel()
t.Run("EmptyConfig", func(t *testing.T) {
t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{})
ctx := testutil.Context(t, testutil.WaitMedium)
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)
require.False(t, params.SyncEnabled)
})
// AllowList has no effect in AGPL
t.Run("AllowList", func(t *testing.T) {
t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
GroupField: "groups",
GroupAllowList: map[string]struct{}{
"foo": {},
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)
require.False(t, params.SyncEnabled)
})
}
func TestGroupSyncTable(t *testing.T) {
t.Parallel()
// Last checked, takes 30s with postgres on a fast machine.
if dbtestutil.WillUsePostgres() {
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
}
userClaims := jwt.MapClaims{
"groups": []string{
"foo", "bar", "baz",
"create-bar", "create-baz",
"legacy-bar",
},
}
ids := coderdtest.NewDeterministicUUIDGenerator()
testCases := []orgSetupDefinition{
{
Name: "SwitchGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")},
"bar": {ids.ID("sg-bar")},
"baz": {ids.ID("sg-baz")},
},
},
Groups: map[uuid.UUID]bool{
uuid.New(): true,
uuid.New(): true,
// Extra groups
ids.ID("sg-foo"): false,
ids.ID("sg-foo-2"): false,
ids.ID("sg-bar"): false,
ids.ID("sg-baz"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("sg-foo"),
ids.ID("sg-foo-2"),
ids.ID("sg-bar"),
ids.ID("sg-baz"),
},
},
{
Name: "StayInGroup",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
// Only match foo, so bar does not map
RegexFilter: regexp.MustCompile("^foo$"),
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("gg-foo"), uuid.New()},
"bar": {ids.ID("gg-bar")},
"baz": {ids.ID("gg-baz")},
},
},
Groups: map[uuid.UUID]bool{
ids.ID("gg-foo"): true,
ids.ID("gg-bar"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("gg-foo"),
},
},
{
Name: "UserJoinsGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("ng-foo"), uuid.New()},
"bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")},
"baz": {ids.ID("ng-baz")},
},
},
Groups: map[uuid.UUID]bool{
ids.ID("ng-foo"): false,
ids.ID("ng-bar"): false,
ids.ID("ng-bar-2"): false,
ids.ID("ng-baz"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("ng-foo"),
ids.ID("ng-bar"),
ids.ID("ng-bar-2"),
ids.ID("ng-baz"),
},
},
{
Name: "CreateGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile("^create"),
AutoCreateMissing: true,
},
Groups: map[uuid.UUID]bool{},
ExpectedGroupNames: []string{
"create-bar",
"create-baz",
},
},
{
Name: "GroupNamesNoMapping",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile(".*"),
AutoCreateMissing: false,
},
GroupNames: map[string]bool{
"foo": false,
"bar": false,
"goob": true,
},
ExpectedGroupNames: []string{
"foo",
"bar",
},
},
{
Name: "NoUser",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
// Extra ID that does not map to a group
"foo": {ids.ID("ow-foo"), uuid.New()},
},
RegexFilter: nil,
AutoCreateMissing: false,
},
NotMember: true,
Groups: map[uuid.UUID]bool{
ids.ID("ow-foo"): false,
ids.ID("ow-bar"): false,
},
},
{
Name: "NoSettingsNoUser",
Settings: nil,
Groups: map[uuid.UUID]bool{},
},
{
Name: "LegacyMapping",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile("^legacy"),
LegacyNameMapping: map[string]string{
"create-bar": "legacy-bar",
"foo": "legacy-foo",
"bop": "legacy-bop",
},
AutoCreateMissing: true,
},
Groups: map[uuid.UUID]bool{
ids.ID("lg-foo"): true,
},
GroupNames: map[string]bool{
"legacy-foo": false,
"extra": true,
"legacy-bop": true,
},
ExpectedGroupNames: []string{
"legacy-bar",
"legacy-foo",
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
idpsync.DeploymentSyncSettings{
GroupField: "groups",
},
)
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
orgID := uuid.New()
SetupOrganization(t, s, db, user, orgID, tc)
// Do the group sync!
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: true,
MergedClaims: userClaims,
})
require.NoError(t, err)
tc.Assert(t, orgID, db, user)
})
}
// AllTogether runs the entire tabled test as a singular user and
// deployment. This tests all organizations being synced together.
// The reason we do them individually, is that it is much easier to
// debug a single test case.
t.Run("AllTogether", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
// Also sync the default org!
idpsync.DeploymentSyncSettings{
GroupField: "groups",
Legacy: idpsync.DefaultOrgLegacySettings{
GroupField: "groups",
GroupMapping: map[string]string{
"foo": "legacy-foo",
"baz": "legacy-baz",
},
GroupFilter: regexp.MustCompile("^legacy"),
CreateMissingGroups: true,
},
},
)
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
var asserts []func(t *testing.T)
// The default org is also going to do something
def := orgSetupDefinition{
Name: "DefaultOrg",
GroupNames: map[string]bool{
"legacy-foo": false,
"legacy-baz": true,
"random": true,
},
// No settings, because they come from the deployment values
Settings: nil,
ExpectedGroups: nil,
ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"},
}
//nolint:gocritic // testing
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
require.NoError(t, err)
SetupOrganization(t, s, db, user, defOrg.ID, def)
asserts = append(asserts, func(t *testing.T) {
t.Run(def.Name, func(t *testing.T) {
t.Parallel()
def.Assert(t, defOrg.ID, db, user)
})
})
for _, tc := range testCases {
tc := tc
orgID := uuid.New()
SetupOrganization(t, s, db, user, orgID, tc)
asserts = append(asserts, func(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
tc.Assert(t, orgID, db, user)
})
})
}
asserts = append(asserts, func(t *testing.T) {
t.Helper()
def.Assert(t, defOrg.ID, db, user)
})
// Do the group sync!
err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: true,
MergedClaims: userClaims,
})
require.NoError(t, err)
for _, assert := range asserts {
assert(t)
}
})
}
func TestSyncDisabled(t *testing.T) {
t.Parallel()
if dbtestutil.WillUsePostgres() {
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
}
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
idpsync.DeploymentSyncSettings{},
)
ids := coderdtest.NewDeterministicUUIDGenerator()
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
orgID := uuid.New()
def := orgSetupDefinition{
Name: "SyncDisabled",
Groups: map[uuid.UUID]bool{
ids.ID("foo"): true,
ids.ID("bar"): true,
ids.ID("baz"): false,
ids.ID("bop"): false,
},
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("foo")},
"baz": {ids.ID("baz")},
},
},
ExpectedGroups: []uuid.UUID{
ids.ID("foo"),
ids.ID("bar"),
},
}
SetupOrganization(t, s, db, user, orgID, def)
// Do the group sync!
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: false,
MergedClaims: jwt.MapClaims{
"groups": []string{"baz", "bop"},
},
})
require.NoError(t, err)
def.Assert(t, orgID, db, user)
}
// TestApplyGroupDifference is mainly testing the database functions
func TestApplyGroupDifference(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
testCase := []struct {
Name string
Before map[uuid.UUID]bool
Add []uuid.UUID
Remove []uuid.UUID
Expect []uuid.UUID
}{
{
Name: "Empty",
},
{
Name: "AddFromNone",
Before: map[uuid.UUID]bool{
ids.ID("g1"): false,
},
Add: []uuid.UUID{
ids.ID("g1"),
},
Expect: []uuid.UUID{
ids.ID("g1"),
},
},
{
Name: "AddSome",
Before: map[uuid.UUID]bool{
ids.ID("g1"): true,
ids.ID("g2"): false,
ids.ID("g3"): false,
uuid.New(): false,
},
Add: []uuid.UUID{
ids.ID("g2"),
ids.ID("g3"),
},
Expect: []uuid.UUID{
ids.ID("g1"),
ids.ID("g2"),
ids.ID("g3"),
},
},
{
Name: "RemoveAll",
Before: map[uuid.UUID]bool{
uuid.New(): false,
ids.ID("g2"): true,
ids.ID("g3"): true,
},
Remove: []uuid.UUID{
ids.ID("g2"),
ids.ID("g3"),
},
Expect: []uuid.UUID{},
},
{
Name: "Mixed",
Before: map[uuid.UUID]bool{
// adds
ids.ID("a1"): true,
ids.ID("a2"): true,
ids.ID("a3"): false,
ids.ID("a4"): false,
// removes
ids.ID("r1"): true,
ids.ID("r2"): true,
ids.ID("r3"): false,
ids.ID("r4"): false,
// stable
ids.ID("s1"): true,
ids.ID("s2"): true,
// noise
uuid.New(): false,
uuid.New(): false,
},
Add: []uuid.UUID{
ids.ID("a1"), ids.ID("a2"),
ids.ID("a3"), ids.ID("a4"),
// Double up to try and confuse
ids.ID("a1"),
ids.ID("a4"),
},
Remove: []uuid.UUID{
ids.ID("r1"), ids.ID("r2"),
ids.ID("r3"), ids.ID("r4"),
// Double up to try and confuse
ids.ID("r1"),
ids.ID("r4"),
},
Expect: []uuid.UUID{
ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"),
ids.ID("s1"), ids.ID("s2"),
},
},
}
for _, tc := range testCase {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
mgr := runtimeconfig.NewManager()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
//nolint:gocritic // testing
ctx = dbauthz.AsSystemRestricted(ctx)
org := dbgen.Organization(t, db, database.Organization{})
_, err := db.InsertAllUsersGroup(ctx, org.ID)
require.NoError(t, err)
user := dbgen.User(t, db, database.User{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
for gid, in := range tc.Before {
group := dbgen.Group(t, db, database.Group{
ID: gid,
OrganizationID: org.ID,
})
if in {
_ = dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: group.ID,
})
}
}
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t)))
err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove)
require.NoError(t, err)
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
HasMemberID: user.ID,
})
require.NoError(t, err)
// assert
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
return g.Group.ID
})
// Add everyone group
require.ElementsMatch(t, append(tc.Expect, org.ID), found)
})
}
}
func TestExpectedGroupEqual(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
testCases := []struct {
Name string
A idpsync.ExpectedGroup
B idpsync.ExpectedGroup
Equal bool
}{
{
Name: "Empty",
A: idpsync.ExpectedGroup{},
B: idpsync.ExpectedGroup{},
Equal: true,
},
{
Name: "DifferentOrgs",
A: idpsync.ExpectedGroup{
OrganizationID: uuid.New(),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: uuid.New(),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
Equal: false,
},
{
Name: "SameID",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
Equal: true,
},
{
Name: "DifferentIDs",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(uuid.New()),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(uuid.New()),
GroupName: nil,
},
Equal: false,
},
{
Name: "SameName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
Equal: true,
},
{
Name: "DifferentName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("bar"),
},
Equal: false,
},
// Edge cases
{
// A bit strange, but valid as ID takes priority.
// We assume 2 groups with the same ID are equal, even if
// their names are different. Names are mutable, IDs are not,
// so there is 0% chance they are different groups.
Name: "DifferentIDSameName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: ptr.Ref("bar"),
},
Equal: true,
},
{
Name: "MixedNils",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("bar"),
},
Equal: false,
},
{
Name: "NoComparable",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: nil,
},
Equal: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.Equal, tc.A.Equal(tc.B))
})
}
}
func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) {
t.Helper()
// Account that the org might be the default organization
org, err := db.GetOrganizationByID(context.Background(), orgID)
if xerrors.Is(err, sql.ErrNoRows) {
org = dbgen.Organization(t, db, database.Organization{
ID: orgID,
})
}
_, err = db.InsertAllUsersGroup(context.Background(), org.ID)
if !database.IsUniqueViolation(err) {
require.NoError(t, err, "Everyone group for an org")
}
manager := runtimeconfig.NewManager()
orgResolver := manager.OrganizationResolver(db, org.ID)
err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings)
require.NoError(t, err)
if !def.NotMember {
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
}
for groupID, in := range def.Groups {
dbgen.Group(t, db, database.Group{
ID: groupID,
OrganizationID: org.ID,
})
if in {
dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: groupID,
})
}
}
for groupName, in := range def.GroupNames {
group := dbgen.Group(t, db, database.Group{
Name: groupName,
OrganizationID: org.ID,
})
if in {
dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: group.ID,
})
}
}
}
type orgSetupDefinition struct {
Name string
// True if the user is a member of the group
Groups map[uuid.UUID]bool
GroupNames map[string]bool
NotMember bool
Settings *idpsync.GroupSyncSettings
ExpectedGroups []uuid.UUID
ExpectedGroupNames []string
}
func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) {
t.Helper()
ctx := context.Background()
members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: orgID,
UserID: user.ID,
})
require.NoError(t, err)
if o.NotMember {
require.Len(t, members, 0, "should not be a member")
} else {
require.Len(t, members, 1, "should be a member")
}
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
OrganizationID: orgID,
HasMemberID: user.ID,
})
require.NoError(t, err)
if o.ExpectedGroups == nil {
o.ExpectedGroups = make([]uuid.UUID, 0)
}
if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 {
t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive")
}
// Everyone groups mess up our asserts
userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool {
return row.Group.ID == row.Group.OrganizationID
})
if len(o.ExpectedGroupNames) > 0 {
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string {
return g.Group.Name
})
require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name")
require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty")
} else {
// Check by ID, recommended
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
return g.Group.ID
})
require.ElementsMatch(t, o.ExpectedGroups, found, "user groups")
require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty")
}
}

View File

@ -3,6 +3,7 @@ package idpsync
import (
"context"
"net/http"
"regexp"
"strings"
"github.com/golang-jwt/jwt/v4"
@ -12,6 +13,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/site"
)
@ -25,21 +27,34 @@ type IDPSync interface {
OrganizationSyncEnabled() bool
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
// organization sync params for assigning users into organizations.
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError)
ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError)
// SyncOrganizations assigns and removed users from organizations based on the
// provided params.
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
GroupSyncEnabled() bool
// ParseGroupClaims takes claims from an OIDC provider, and returns the params
// for group syncing. Most of the logic happens in SyncGroups.
ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError)
// SyncGroups assigns and removes users from groups based on the provided params.
SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error
// GroupSyncSettings is exposed for the API to implement CRUD operations
// on the settings used by IDPSync. This entry is thread safe and can be
// accessed concurrently. The settings are stored in the database.
GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings]
}
// AGPLIDPSync is the configuration for syncing user information from an external
// IDP. All related code to syncing user information should be in this package.
type AGPLIDPSync struct {
Logger slog.Logger
Logger slog.Logger
Manager *runtimeconfig.Manager
SyncSettings
}
type SyncSettings struct {
// DeploymentSyncSettings are static and are sourced from the deployment config.
type DeploymentSyncSettings struct {
// OrganizationField selects the claim field to be used as the created user's
// organizations. If the field is the empty string, then no organization updates
// will ever come from the OIDC provider.
@ -50,23 +65,62 @@ type SyncSettings struct {
// placed into the default organization. This is mostly a hack to support
// legacy deployments.
OrganizationAssignDefault bool
// GroupField at the deployment level is used for deployment level group claim
// settings.
GroupField string
// GroupAllowList (if set) will restrict authentication to only users who
// have at least one group in this list.
// A map representation is used for easier lookup.
GroupAllowList map[string]struct{}
// Legacy deployment settings that only apply to the default org.
Legacy DefaultOrgLegacySettings
}
type OrganizationParams struct {
// SyncEnabled if false will skip syncing the user's organizations.
SyncEnabled bool
// IncludeDefault is primarily for single org deployments. It will ensure
// a user is always inserted into the default org.
IncludeDefault bool
// Organizations is the list of organizations the user should be a member of
// assuming syncing is turned on.
Organizations []uuid.UUID
type DefaultOrgLegacySettings struct {
GroupField string
GroupMapping map[string]string
GroupFilter *regexp.Regexp
CreateMissingGroups bool
}
func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync {
func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings {
if dv == nil {
panic("Developer error: DeploymentValues should not be nil")
}
return DeploymentSyncSettings{
OrganizationField: dv.OIDC.OrganizationField.Value(),
OrganizationMapping: dv.OIDC.OrganizationMapping.Value,
OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(),
// TODO: Separate group field for allow list from default org.
// Right now you cannot disable group sync from the default org and
// configure an allow list.
GroupField: dv.OIDC.GroupField.Value(),
GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()),
Legacy: DefaultOrgLegacySettings{
GroupField: dv.OIDC.GroupField.Value(),
GroupMapping: dv.OIDC.GroupMapping.Value,
GroupFilter: dv.OIDC.GroupRegexFilter.Value(),
CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(),
},
}
}
type SyncSettings struct {
DeploymentSyncSettings
Group runtimeconfig.RuntimeEntry[*GroupSyncSettings]
}
func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync {
return &AGPLIDPSync{
Logger: logger.Named("idp-sync"),
SyncSettings: settings,
Logger: logger.Named("idp-sync"),
Manager: manager,
SyncSettings: SyncSettings{
DeploymentSyncSettings: settings,
Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"),
},
}
}

View File

@ -16,6 +16,17 @@ import (
"github.com/coder/coder/v2/coderd/util/slice"
)
type OrganizationParams struct {
// SyncEnabled if false will skip syncing the user's organizations.
SyncEnabled bool
// IncludeDefault is primarily for single org deployments. It will ensure
// a user is always inserted into the default org.
IncludeDefault bool
// Organizations is the list of organizations the user should be a member of
// assuming syncing is turned on.
Organizations []uuid.UUID
}
func (AGPLIDPSync) OrganizationSyncEnabled() bool {
// AGPL does not support syncing organizations.
return false

View File

@ -9,6 +9,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/testutil"
)
@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) {
t.Run("SingleOrgDeployment", func(t *testing.T) {
t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
OrganizationField: "",
OrganizationMapping: nil,
OrganizationAssignDefault: true,
})
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
OrganizationField: "",
OrganizationMapping: nil,
OrganizationAssignDefault: true,
})
ctx := testutil.Context(t, testutil.WaitMedium)
@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) {
t.Parallel()
// AGPL has limited behavior
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
OrganizationField: "orgs",
OrganizationMapping: map[string][]uuid.UUID{
"random": {uuid.New()},
},
OrganizationAssignDefault: false,
})
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
OrganizationField: "orgs",
OrganizationMapping: map[string][]uuid.UUID{
"random": {uuid.New()},
},
OrganizationAssignDefault: false,
})
ctx := testutil.Context(t, testutil.WaitMedium)

View File

@ -2,6 +2,7 @@ package runtimeconfig
import (
"context"
"encoding/json"
"fmt"
"golang.org/x/xerrors"
@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) {
return e.n, nil
}
func JSONString(v any) string {
s, err := json.Marshal(v)
if err != nil {
return "decode failed: " + err.Error()
}
return string(s)
}

View File

@ -8,7 +8,6 @@ import (
"fmt"
"net/http"
"net/mail"
"regexp"
"sort"
"strconv"
"strings"
@ -20,7 +19,6 @@ import (
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
@ -659,6 +657,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
AvatarURL: ghUser.GetAvatarURL(),
Name: normName,
DebugContext: OauthDebugContext{},
GroupSync: idpsync.GroupParams{
SyncEnabled: false,
},
OrganizationSync: idpsync.OrganizationParams{
SyncEnabled: false,
IncludeDefault: true,
@ -739,27 +740,6 @@ type OIDCConfig struct {
// support the userinfo endpoint, or if the userinfo endpoint causes
// undesirable behavior.
IgnoreUserInfo bool
// TODO: Move all idp fields into the IDPSync struct
// GroupField selects the claim field to be used as the created user's
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
GroupField string
// CreateMissingGroups controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
CreateMissingGroups bool
// GroupFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
GroupFilter *regexp.Regexp
// GroupAllowList is a list of groups that are allowed to log in.
// If the list length is 0, then the allow list will not be applied and
// this feature is disabled.
GroupAllowList map[string]bool
// GroupMapping controls how groups returned by the OIDC provider get mapped
// to groups within Coder.
// map[oidcGroupName]coderGroupName
GroupMapping map[string]string
// UserRoleField selects the claim field to be used as the created user's
// roles. If the field is the empty string, then no role updates
// will ever come from the OIDC provider.
@ -1002,11 +982,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name))
usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims)
if groupErr != nil {
groupErr.Write(rw, r)
return
}
roles, roleErr := api.oidcRoles(ctx, mergedClaims)
if roleErr != nil {
@ -1030,6 +1005,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
return
}
groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims)
if groupSyncErr != nil {
groupSyncErr.Write(rw, r)
return
}
// If a new user is authenticating for the first time
// the audit action is 'register', not 'login'
if user.ID == uuid.Nil {
@ -1037,23 +1018,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
params := (&oauthLoginParams{
User: user,
Link: link,
State: state,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC,
AllowSignups: api.OIDCConfig.AllowSignups,
Email: email,
Username: username,
Name: name,
AvatarURL: picture,
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
Roles: roles,
UsingGroups: usingGroups,
Groups: groups,
OrganizationSync: orgSync,
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
GroupFilter: api.OIDCConfig.GroupFilter,
User: user,
Link: link,
State: state,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC,
AllowSignups: api.OIDCConfig.AllowSignups,
Email: email,
Username: username,
Name: name,
AvatarURL: picture,
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
Roles: roles,
OrganizationSync: orgSync,
GroupSync: groupSync,
DebugContext: OauthDebugContext{
IDTokenClaims: idtokenClaims,
UserInfoClaims: userInfoClaims,
@ -1089,79 +1067,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}
// oidcGroups returns the groups for the user from the OIDC claims.
func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) {
logger := api.Logger.Named(userAuthLoggerName)
usingGroups := false
var groups []string
// If the GroupField is the empty string, then groups from OIDC are not used.
// This is so we can support manual group assignment.
if api.OIDCConfig.GroupField != "" {
// If the allow list is empty, then the user is allowed to log in.
// Otherwise, they must belong to at least 1 group in the allow list.
inAllowList := len(api.OIDCConfig.GroupAllowList) == 0
usingGroups = true
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
if ok {
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
if err != nil {
api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims",
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
slog.Error(err),
)
return false, nil, &idpsync.HTTPError{
Code: http.StatusBadRequest,
Msg: "Failed to sync groups from OIDC claims",
Detail: err.Error(),
RenderStaticPage: false,
}
}
api.Logger.Debug(ctx, "groups returned in oidc claims",
slog.F("len", len(parsedGroups)),
slog.F("groups", parsedGroups),
)
for _, group := range parsedGroups {
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
group = mappedGroup
}
if _, ok := api.OIDCConfig.GroupAllowList[group]; ok {
inAllowList = true
}
groups = append(groups, group)
}
}
if !inAllowList {
logger.Debug(ctx, "oidc group claim not in allow list, rejecting login",
slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)),
slog.F("user_group_count", len(groups)),
)
detail := "Ask an administrator to add one of your groups to the whitelist"
if len(groups) == 0 {
detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login."
}
return usingGroups, groups, &idpsync.HTTPError{
Code: http.StatusForbidden,
Msg: "Not a member of an allowed group",
Detail: detail,
RenderStaticPage: true,
}
}
}
// This conditional is purely to warn the user they might have misconfigured their OIDC
// configuration.
if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists {
logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings")
}
return usingGroups, groups, nil
}
// oidcRoles returns the roles for the user from the OIDC claims.
// If the function returns false, then the caller should return early.
// All writes to the response writer are handled by this function.
@ -1276,14 +1181,7 @@ type oauthLoginParams struct {
AvatarURL string
// OrganizationSync has the organizations that the user will be assigned to.
OrganizationSync idpsync.OrganizationParams
// Is UsingGroups is true, then the user will be assigned
// to the Groups provided.
UsingGroups bool
CreateMissingGroups bool
// These are the group names from the IDP. Internally, they will map to
// some organization groups.
Groups []string
GroupFilter *regexp.Regexp
GroupSync idpsync.GroupParams
// Is UsingRoles is true, then the user will be assigned
// the roles provided.
UsingRoles bool
@ -1489,53 +1387,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
return xerrors.Errorf("sync organizations: %w", err)
}
// Ensure groups are correct.
// This places all groups into the default organization.
// To go multi-org, we need to add a mapping feature here to know which
// groups go to which orgs.
if params.UsingGroups {
filtered := params.Groups
if params.GroupFilter != nil {
filtered = make([]string, 0, len(params.Groups))
for _, group := range params.Groups {
if params.GroupFilter.MatchString(group) {
filtered = append(filtered, group)
}
}
}
//nolint:gocritic // No user present in the context.
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
if err != nil {
// If there is no default org, then we can't assign groups.
// By default, we assume all groups belong to the default org.
return xerrors.Errorf("get default organization: %w", err)
}
//nolint:gocritic // No user present in the context.
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
UserID: user.ID,
OrganizationID: uuid.Nil,
})
if err != nil {
return xerrors.Errorf("get organization memberships: %w", err)
}
// If the user is not in the default organization, then we can't assign groups.
// A user cannot be in groups to an org they are not a member of.
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
}) {
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
}
//nolint:gocritic
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
defaultOrganization.ID: filtered,
}, params.CreateMissingGroups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
// Group sync needs to occur after org sync, since a user can join an org,
// then have their groups sync to said org.
err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync)
if err != nil {
return xerrors.Errorf("sync groups: %w", err)
}
// Ensure roles are correct.