feat: add auto group create from OIDC (#8884)

* add flag for auto create groups
* fixup! add flag for auto create groups
* sync missing groups
Also added a regex filter to filter out groups that are not
important
This commit is contained in:
Steven Masley
2023-08-08 11:37:49 -05:00
committed by GitHub
parent 4a987e9917
commit f4122fa9f5
35 changed files with 887 additions and 128 deletions

23
coderd/apidoc/docs.go generated
View File

@ -6624,6 +6624,9 @@ const docTemplate = `{
}
}
},
"clibase.Regexp": {
"type": "object"
},
"clibase.Struct-array_codersdk_GitAuthConfig": {
"type": "object",
"properties": {
@ -8274,9 +8277,23 @@ const docTemplate = `{
},
"quota_allowance": {
"type": "integer"
},
"source": {
"$ref": "#/definitions/codersdk.GroupSource"
}
}
},
"codersdk.GroupSource": {
"type": "string",
"enum": [
"user",
"oidc"
],
"x-enum-varnames": [
"GroupSourceUser",
"GroupSourceOIDC"
]
},
"codersdk.Healthcheck": {
"type": "object",
"properties": {
@ -8583,9 +8600,15 @@ const docTemplate = `{
"email_field": {
"type": "string"
},
"group_auto_create": {
"type": "boolean"
},
"group_mapping": {
"type": "object"
},
"group_regex_filter": {
"$ref": "#/definitions/clibase.Regexp"
},
"groups_field": {
"type": "string"
},

View File

@ -5874,6 +5874,9 @@
}
}
},
"clibase.Regexp": {
"type": "object"
},
"clibase.Struct-array_codersdk_GitAuthConfig": {
"type": "object",
"properties": {
@ -7430,9 +7433,17 @@
},
"quota_allowance": {
"type": "integer"
},
"source": {
"$ref": "#/definitions/codersdk.GroupSource"
}
}
},
"codersdk.GroupSource": {
"type": "string",
"enum": ["user", "oidc"],
"x-enum-varnames": ["GroupSourceUser", "GroupSourceOIDC"]
},
"codersdk.Healthcheck": {
"type": "object",
"properties": {
@ -7703,9 +7714,15 @@
"email_field": {
"type": "string"
},
"group_auto_create": {
"type": "boolean"
},
"group_mapping": {
"type": "object"
},
"group_regex_filter": {
"$ref": "#/definitions/clibase.Regexp"
},
"groups_field": {
"type": "string"
},

View File

@ -127,8 +127,8 @@ type Options struct {
BaseDERPMap *tailcfg.DERPMap
DERPMapUpdateFrequency time.Duration
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, tx database.Store, userID uuid.UUID, groupNames []string) error
SetUserSiteRoles func(ctx context.Context, tx database.Store, userID uuid.UUID, roles []string) error
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, groupNames []string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
// AppSecurityKey is the crypto key used to sign and encrypt tokens related to
@ -262,16 +262,16 @@ func New(options *Options) *API {
options.TracerProvider = trace.NewNoopTracerProvider()
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(ctx context.Context, _ database.Store, userID uuid.UUID, groups []string) error {
options.Logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID), slog.F("groups", groups),
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, groups []string, createMissingGroups bool) error {
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID), slog.F("groups", groups), slog.F("create_missing_groups", createMissingGroups),
)
return nil
}
}
if options.SetUserSiteRoles == nil {
options.SetUserSiteRoles = func(ctx context.Context, _ database.Store, userID uuid.UUID, roles []string) error {
options.Logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
slog.F("user_id", userID), slog.F("roles", roles),
)
return nil

View File

@ -1853,6 +1853,13 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
return q.db.InsertLicense(ctx, arg)
}
func (q *querier) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.InsertMissingGroups(ctx, arg)
}
func (q *querier) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
return insert(q.log, q.auth, rbac.ResourceOrganization, q.db.InsertOrganization)(ctx, arg)
}

View File

@ -3641,6 +3641,7 @@ func (q *FakeQuerier) InsertGroup(_ context.Context, arg database.InsertGroupPar
OrganizationID: arg.OrganizationID,
AvatarURL: arg.AvatarURL,
QuotaAllowance: arg.QuotaAllowance,
Source: database.GroupSourceUser,
}
q.groups = append(q.groups, group)
@ -3693,6 +3694,45 @@ func (q *FakeQuerier) InsertLicense(
return l, nil
}
func (q *FakeQuerier) InsertMissingGroups(_ context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}
groupNameMap := make(map[string]struct{})
for _, g := range arg.GroupNames {
groupNameMap[g] = struct{}{}
}
q.mutex.Lock()
defer q.mutex.Unlock()
for _, g := range q.groups {
if g.OrganizationID != arg.OrganizationID {
continue
}
delete(groupNameMap, g.Name)
}
newGroups := make([]database.Group, 0, len(groupNameMap))
for k := range groupNameMap {
g := database.Group{
ID: uuid.New(),
Name: k,
OrganizationID: arg.OrganizationID,
AvatarURL: "",
QuotaAllowance: 0,
DisplayName: "",
Source: arg.Source,
}
q.groups = append(q.groups, g)
newGroups = append(newGroups, g)
}
return newGroups, nil
}
func (q *FakeQuerier) InsertOrganization(_ context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
if err := validateDatabaseType(arg); err != nil {
return database.Organization{}, err

View File

@ -1110,6 +1110,13 @@ func (m metricsStore) InsertLicense(ctx context.Context, arg database.InsertLice
return license, err
}
func (m metricsStore) InsertMissingGroups(ctx context.Context, arg database.InsertMissingGroupsParams) ([]database.Group, error) {
start := time.Now()
r0, r1 := m.s.InsertMissingGroups(ctx, arg)
m.queryLatencies.WithLabelValues("InsertMissingGroups").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) InsertOrganization(ctx context.Context, arg database.InsertOrganizationParams) (database.Organization, error) {
start := time.Now()
organization, err := m.s.InsertOrganization(ctx, arg)

View File

@ -2332,6 +2332,21 @@ func (mr *MockStoreMockRecorder) InsertLicense(arg0, arg1 interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), arg0, arg1)
}
// InsertMissingGroups mocks base method.
func (m *MockStore) InsertMissingGroups(arg0 context.Context, arg1 database.InsertMissingGroupsParams) ([]database.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertMissingGroups", arg0, arg1)
ret0, _ := ret[0].([]database.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMissingGroups indicates an expected call of InsertMissingGroups.
func (mr *MockStoreMockRecorder) InsertMissingGroups(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMissingGroups", reflect.TypeOf((*MockStore)(nil).InsertMissingGroups), arg0, arg1)
}
// InsertOrganization mocks base method.
func (m *MockStore) InsertOrganization(arg0 context.Context, arg1 database.InsertOrganizationParams) (database.Organization, error) {
m.ctrl.T.Helper()

View File

@ -31,6 +31,11 @@ CREATE TYPE build_reason AS ENUM (
'autodelete'
);
CREATE TYPE group_source AS ENUM (
'user',
'oidc'
);
CREATE TYPE log_level AS ENUM (
'trace',
'debug',
@ -299,11 +304,14 @@ CREATE TABLE groups (
organization_id uuid NOT NULL,
avatar_url text DEFAULT ''::text NOT NULL,
quota_allowance integer DEFAULT 0 NOT NULL,
display_name text DEFAULT ''::text NOT NULL
display_name text DEFAULT ''::text NOT NULL,
source group_source DEFAULT 'user'::group_source NOT NULL
);
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
COMMENT ON COLUMN groups.source IS 'Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.';
CREATE TABLE licenses (
id integer NOT NULL,
uploaded_at timestamp with time zone NOT NULL,

View File

@ -0,0 +1,8 @@
BEGIN;
ALTER TABLE groups
DROP COLUMN source;
DROP TYPE group_source;
COMMIT;

View File

@ -0,0 +1,15 @@
BEGIN;
CREATE TYPE group_source AS ENUM (
-- User created groups
'user',
-- Groups created by the system through oidc sync
'oidc'
);
ALTER TABLE groups
ADD COLUMN source group_source NOT NULL DEFAULT 'user';
COMMENT ON COLUMN groups.source IS 'Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.';
COMMIT;

View File

@ -281,6 +281,64 @@ func AllBuildReasonValues() []BuildReason {
}
}
type GroupSource string
const (
GroupSourceUser GroupSource = "user"
GroupSourceOidc GroupSource = "oidc"
)
func (e *GroupSource) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = GroupSource(s)
case string:
*e = GroupSource(s)
default:
return fmt.Errorf("unsupported scan type for GroupSource: %T", src)
}
return nil
}
type NullGroupSource struct {
GroupSource GroupSource `json:"group_source"`
Valid bool `json:"valid"` // Valid is true if GroupSource is not NULL
}
// Scan implements the Scanner interface.
func (ns *NullGroupSource) Scan(value interface{}) error {
if value == nil {
ns.GroupSource, ns.Valid = "", false
return nil
}
ns.Valid = true
return ns.GroupSource.Scan(value)
}
// Value implements the driver Valuer interface.
func (ns NullGroupSource) Value() (driver.Value, error) {
if !ns.Valid {
return nil, nil
}
return string(ns.GroupSource), nil
}
func (e GroupSource) Valid() bool {
switch e {
case GroupSourceUser,
GroupSourceOidc:
return true
}
return false
}
func AllGroupSourceValues() []GroupSource {
return []GroupSource{
GroupSourceUser,
GroupSourceOidc,
}
}
type LogLevel string
const (
@ -1498,6 +1556,8 @@ type Group struct {
QuotaAllowance int32 `db:"quota_allowance" json:"quota_allowance"`
// Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.
DisplayName string `db:"display_name" json:"display_name"`
// Source indicates how the group was created. It can be created by a user manually, or through some system process like OIDC group sync.
Source GroupSource `db:"source" json:"source"`
}
type GroupMember struct {

View File

@ -206,6 +206,11 @@ type sqlcQuerier interface {
InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error)
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
// Inserts any group by name that does not exist. All new groups are given
// a random uuid, are inserted into the same organization. They have the default
// values for avatar, display name, and quota allowance (all zero values).
// If the name conflicts, do nothing.
InsertMissingGroups(ctx context.Context, arg InsertMissingGroupsParams) ([]Group, error)
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)

View File

@ -1180,7 +1180,7 @@ func (q *sqlQuerier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
const getGroupByID = `-- name: GetGroupByID :one
SELECT
id, name, organization_id, avatar_url, quota_allowance, display_name
id, name, organization_id, avatar_url, quota_allowance, display_name, source
FROM
groups
WHERE
@ -1199,13 +1199,14 @@ func (q *sqlQuerier) GetGroupByID(ctx context.Context, id uuid.UUID) (Group, err
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
)
return i, err
}
const getGroupByOrgAndName = `-- name: GetGroupByOrgAndName :one
SELECT
id, name, organization_id, avatar_url, quota_allowance, display_name
id, name, organization_id, avatar_url, quota_allowance, display_name, source
FROM
groups
WHERE
@ -1231,13 +1232,14 @@ func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrg
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
)
return i, err
}
const getGroupsByOrganizationID = `-- name: GetGroupsByOrganizationID :many
SELECT
id, name, organization_id, avatar_url, quota_allowance, display_name
id, name, organization_id, avatar_url, quota_allowance, display_name, source
FROM
groups
WHERE
@ -1262,6 +1264,7 @@ func (q *sqlQuerier) GetGroupsByOrganizationID(ctx context.Context, organization
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
); err != nil {
return nil, err
}
@ -1283,7 +1286,7 @@ INSERT INTO groups (
organization_id
)
VALUES
($1, 'Everyone', $1) RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name
($1, 'Everyone', $1) RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name, source
`
// We use the organization_id as the id
@ -1299,6 +1302,7 @@ func (q *sqlQuerier) InsertAllUsersGroup(ctx context.Context, organizationID uui
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
)
return i, err
}
@ -1313,7 +1317,7 @@ INSERT INTO groups (
quota_allowance
)
VALUES
($1, $2, $3, $4, $5, $6) RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name
($1, $2, $3, $4, $5, $6) RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name, source
`
type InsertGroupParams struct {
@ -1342,10 +1346,70 @@ func (q *sqlQuerier) InsertGroup(ctx context.Context, arg InsertGroupParams) (Gr
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
)
return i, err
}
const insertMissingGroups = `-- name: InsertMissingGroups :many
INSERT INTO groups (
id,
name,
organization_id,
source
)
SELECT
gen_random_uuid(),
group_name,
$1,
$2
FROM
UNNEST($3 :: text[]) AS group_name
ON CONFLICT DO NOTHING
RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name, source
`
type InsertMissingGroupsParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Source GroupSource `db:"source" json:"source"`
GroupNames []string `db:"group_names" json:"group_names"`
}
// Inserts any group by name that does not exist. All new groups are given
// a random uuid, are inserted into the same organization. They have the default
// values for avatar, display name, and quota allowance (all zero values).
// If the name conflicts, do nothing.
func (q *sqlQuerier) InsertMissingGroups(ctx context.Context, arg InsertMissingGroupsParams) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, insertMissingGroups, arg.OrganizationID, arg.Source, pq.Array(arg.GroupNames))
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(
&i.ID,
&i.Name,
&i.OrganizationID,
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateGroupByID = `-- name: UpdateGroupByID :one
UPDATE
groups
@ -1356,7 +1420,7 @@ SET
quota_allowance = $4
WHERE
id = $5
RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name
RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name, source
`
type UpdateGroupByIDParams struct {
@ -1383,6 +1447,7 @@ func (q *sqlQuerier) UpdateGroupByID(ctx context.Context, arg UpdateGroupByIDPar
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
)
return i, err
}

View File

@ -42,6 +42,28 @@ INSERT INTO groups (
VALUES
($1, $2, $3, $4, $5, $6) RETURNING *;
-- name: InsertMissingGroups :many
-- Inserts any group by name that does not exist. All new groups are given
-- a random uuid, are inserted into the same organization. They have the default
-- values for avatar, display name, and quota allowance (all zero values).
INSERT INTO groups (
id,
name,
organization_id,
source
)
SELECT
gen_random_uuid(),
group_name,
@organization_id,
@source
FROM
UNNEST(@group_names :: text[]) AS group_name
-- If the name conflicts, do nothing.
ON CONFLICT DO NOTHING
RETURNING *;
-- We use the organization_id as the id
-- for simplicity since all users is
-- every member of the org.

View File

@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/mail"
"regexp"
"sort"
"strconv"
"strings"
@ -688,6 +689,13 @@ type OIDCConfig struct {
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
GroupField string
// CreateMissingGroups controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
CreateMissingGroups bool
// GroupFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
GroupFilter *regexp.Regexp
// GroupMapping controls how groups returned by the OIDC provider get mapped
// to groups within Coder.
// map[oidcGroupName]coderGroupName
@ -1029,19 +1037,21 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
params := (&oauthLoginParams{
User: user,
Link: link,
State: state,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC,
AllowSignups: api.OIDCConfig.AllowSignups,
Email: email,
Username: username,
AvatarURL: picture,
UsingGroups: usingGroups,
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
Roles: roles,
Groups: groups,
User: user,
Link: link,
State: state,
LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC,
AllowSignups: api.OIDCConfig.AllowSignups,
Email: email,
Username: username,
AvatarURL: picture,
UsingGroups: usingGroups,
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
Roles: roles,
Groups: groups,
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
GroupFilter: api.OIDCConfig.GroupFilter,
}).SetInitAuditRequest(func(params *audit.RequestParams) (*audit.Request[database.User], func()) {
return audit.InitRequest[database.User](rw, params)
})
@ -1125,8 +1135,10 @@ type oauthLoginParams struct {
AvatarURL string
// Is UsingGroups is true, then the user will be assigned
// to the Groups provided.
UsingGroups bool
Groups []string
UsingGroups bool
CreateMissingGroups bool
Groups []string
GroupFilter *regexp.Regexp
// Is UsingRoles is true, then the user will be assigned
// the roles provided.
UsingRoles bool
@ -1342,8 +1354,18 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
// Ensure groups are correct.
if params.UsingGroups {
filtered := params.Groups
if params.GroupFilter != nil {
filtered = make([]string, 0, len(params.Groups))
for _, group := range params.Groups {
if params.GroupFilter.MatchString(group) {
filtered = append(filtered, group)
}
}
}
//nolint:gocritic
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), tx, user.ID, params.Groups)
err := api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered, params.CreateMissingGroups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
@ -1362,7 +1384,7 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
//nolint:gocritic
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), tx, user.ID, filtered)
err := api.Options.SetUserSiteRoles(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, filtered)
if err != nil {
return httpError{
code: http.StatusBadRequest,