feat: assign users to groups returned by OIDC provider (#5965)

This commit is contained in:
Colin Adler
2023-02-02 13:53:48 -06:00
committed by GitHub
parent 026b1cd2a4
commit 496138b086
11 changed files with 477 additions and 133 deletions

View File

@ -115,6 +115,7 @@ type Options struct {
DERPServer *derp.Server
DERPMap *tailcfg.DERPMap
SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, tx database.Store, userID uuid.UUID, groupNames []string) error
// APIRateLimit is the minutely throughput rate limit per user or ip.
// Setting a rate limit <0 will disable the rate limiter across the entire
@ -202,6 +203,9 @@ func New(options *Options) *API {
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
if options.SetUserGroups == nil {
options.SetUserGroups = func(context.Context, database.Store, uuid.UUID, []string) error { return nil }
}
siteCacheDir := options.CacheDir
if siteCacheDir != "" {

View File

@ -3528,6 +3528,50 @@ func (q *fakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database
return nil
}
func (q *fakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
var groupIDs []uuid.UUID
for _, group := range q.groups {
for _, groupName := range arg.GroupNames {
if group.Name == groupName {
groupIDs = append(groupIDs, group.ID)
}
}
}
for _, groupID := range groupIDs {
q.groupMembers = append(q.groupMembers, database.GroupMember{
UserID: arg.UserID,
GroupID: groupID,
})
}
return nil
}
func (q *fakeQuerier) DeleteGroupMembersByOrgAndUser(_ context.Context, arg database.DeleteGroupMembersByOrgAndUserParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
newMembers := q.groupMembers[:0]
for _, member := range q.groupMembers {
if member.UserID == arg.UserID {
for _, group := range q.groups {
if group.ID == member.GroupID && group.OrganizationID == arg.OrganizationID {
continue
}
newMembers = append(newMembers, member)
}
}
}
q.groupMembers = newMembers
return nil
}
func (q *fakeQuerier) UpdateGroupByID(_ context.Context, arg database.UpdateGroupByIDParams) (database.Group, error) {
if err := validateDatabaseType(arg); err != nil {
return database.Group{}, err

View File

@ -24,6 +24,7 @@ type sqlcQuerier interface {
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOldAgentStats(ctx context.Context) error
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
@ -158,6 +159,8 @@ type sqlcQuerier interface {
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
InsertTemplateVersionParameter(ctx context.Context, arg InsertTemplateVersionParameterParams) (TemplateVersionParameter, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)

View File

@ -978,18 +978,6 @@ func (q *sqlQuerier) UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyPar
return i, err
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM
groups
WHERE
id = $1
`
func (q *sqlQuerier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteGroupByID, id)
return err
}
const deleteGroupMemberFromGroup = `-- name: DeleteGroupMemberFromGroup :exec
DELETE FROM
group_members
@ -1008,6 +996,143 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG
return err
}
const deleteGroupMembersByOrgAndUser = `-- name: DeleteGroupMembersByOrgAndUser :exec
DELETE FROM
group_members
USING
group_members AS gm
LEFT JOIN
groups
ON
groups.id = gm.group_id
WHERE
groups.organization_id = $1 AND
gm.user_id = $2
`
type DeleteGroupMembersByOrgAndUserParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) DeleteGroupMembersByOrgAndUser(ctx context.Context, arg DeleteGroupMembersByOrgAndUserParams) error {
_, err := q.db.ExecContext(ctx, deleteGroupMembersByOrgAndUser, arg.OrganizationID, arg.UserID)
return err
}
const getGroupMembers = `-- name: GetGroupMembers :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false'
`
func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertGroupMember = `-- name: InsertGroupMember :exec
INSERT INTO
group_members (user_id, group_id)
VALUES
($1, $2)
`
type InsertGroupMemberParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupID uuid.UUID `db:"group_id" json:"group_id"`
}
func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error {
_, err := q.db.ExecContext(ctx, insertGroupMember, arg.UserID, arg.GroupID)
return err
}
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.organization_id = $2 AND
groups.name = ANY($3 :: text [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
$1,
groups.id
FROM
groups
`
type InsertUserGroupsByNameParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
GroupNames []string `db:"group_names" json:"group_names"`
}
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
func (q *sqlQuerier) InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error {
_, err := q.db.ExecContext(ctx, insertUserGroupsByName, arg.UserID, arg.OrganizationID, pq.Array(arg.GroupNames))
return err
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM
groups
WHERE
id = $1
`
func (q *sqlQuerier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
_, err := q.db.ExecContext(ctx, deleteGroupByID, id)
return err
}
const getGroupByID = `-- name: GetGroupByID :one
SELECT
id, name, organization_id, avatar_url, quota_allowance
@ -1063,59 +1188,6 @@ func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrg
return i, err
}
const getGroupMembers = `-- name: GetGroupMembers :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false'
`
func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []User
for rows.Next() {
var i User
if err := rows.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupsByOrganizationID = `-- name: GetGroupsByOrganizationID :many
SELECT
id, name, organization_id, avatar_url, quota_allowance
@ -1163,7 +1235,7 @@ INSERT INTO groups (
organization_id
)
VALUES
( $1, 'Everyone', $1) RETURNING id, name, organization_id, avatar_url, quota_allowance
($1, 'Everyone', $1) RETURNING id, name, organization_id, avatar_url, quota_allowance
`
// We use the organization_id as the id
@ -1191,7 +1263,7 @@ INSERT INTO groups (
quota_allowance
)
VALUES
( $1, $2, $3, $4, $5) RETURNING id, name, organization_id, avatar_url, quota_allowance
($1, $2, $3, $4, $5) RETURNING id, name, organization_id, avatar_url, quota_allowance
`
type InsertGroupParams struct {
@ -1221,24 +1293,6 @@ func (q *sqlQuerier) InsertGroup(ctx context.Context, arg InsertGroupParams) (Gr
return i, err
}
const insertGroupMember = `-- name: InsertGroupMember :exec
INSERT INTO group_members (
user_id,
group_id
)
VALUES ($1, $2)
`
type InsertGroupMemberParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupID uuid.UUID `db:"group_id" json:"group_id"`
}
func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error {
_, err := q.db.ExecContext(ctx, insertGroupMember, arg.UserID, arg.GroupID)
return err
}
const updateGroupByID = `-- name: UpdateGroupByID :one
UPDATE
groups

View File

@ -0,0 +1,60 @@
-- name: GetGroupMembers :many
SELECT
users.*
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false';
-- InsertUserGroupsByName adds a user to all provided groups, if they exist.
-- name: InsertUserGroupsByName :exec
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.organization_id = @organization_id AND
groups.name = ANY(@group_names :: text [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
@user_id,
groups.id
FROM
groups;
-- name: DeleteGroupMembersByOrgAndUser :exec
DELETE FROM
group_members
USING
group_members AS gm
LEFT JOIN
groups
ON
groups.id = gm.group_id
WHERE
groups.organization_id = @organization_id AND
gm.user_id = @user_id;
-- name: InsertGroupMember :exec
INSERT INTO
group_members (user_id, group_id)
VALUES
($1, $2);
-- name: DeleteGroupMemberFromGroup :exec
DELETE FROM
group_members
WHERE
user_id = $1 AND
group_id = $2;

View File

@ -20,22 +20,6 @@ AND
LIMIT
1;
-- name: GetGroupMembers :many
SELECT
users.*
FROM
users
JOIN
group_members
ON
users.id = group_members.user_id
WHERE
group_members.group_id = $1
AND
users.status = 'active'
AND
users.deleted = 'false';
-- name: GetGroupsByOrganizationID :many
SELECT
*
@ -55,7 +39,7 @@ INSERT INTO groups (
quota_allowance
)
VALUES
( $1, $2, $3, $4, $5) RETURNING *;
($1, $2, $3, $4, $5) RETURNING *;
-- We use the organization_id as the id
-- for simplicity since all users is
@ -67,7 +51,7 @@ INSERT INTO groups (
organization_id
)
VALUES
( sqlc.arg(organization_id), 'Everyone', sqlc.arg(organization_id)) RETURNING *;
(sqlc.arg(organization_id), 'Everyone', sqlc.arg(organization_id)) RETURNING *;
-- name: UpdateGroupByID :one
UPDATE
@ -80,20 +64,6 @@ WHERE
id = $4
RETURNING *;
-- name: InsertGroupMember :exec
INSERT INTO group_members (
user_id,
group_id
)
VALUES ($1, $2);
-- name: DeleteGroupMemberFromGroup :exec
DELETE FROM
group_members
WHERE
user_id = $1 AND
group_id = $2;
-- name: DeleteGroupByID :exec
DELETE FROM
groups

View File

@ -311,6 +311,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
if ok {
username, _ = usernameRaw.(string)
}
emailRaw, ok := claims["email"]
if !ok {
// Email is an optional claim in OIDC and
@ -326,6 +327,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
emailRaw = username
}
email, ok := emailRaw.(string)
if !ok {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -333,6 +335,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
})
return
}
verifiedRaw, ok := claims["email_verified"]
if ok {
verified, ok := verifiedRaw.(bool)
@ -346,6 +349,26 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
api.Logger.Warn(ctx, "allowing unverified oidc email %q")
}
}
var groups []string
groupsRaw, ok := claims["groups"]
if ok {
// Convert the []interface{} we get to a []string.
groupsInterface, ok := groupsRaw.([]interface{})
if ok {
for _, groupInterface := range groupsInterface {
group, ok := groupInterface.(string)
if !ok {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid group type. Expected string, got: %t", emailRaw),
})
return
}
groups = append(groups, group)
}
}
}
// The username is a required property in Coder. We make a best-effort
// attempt at using what the claims provide, but if that fails we will
// generate a random username.
@ -359,6 +382,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
}
username = httpapi.UsernameFrom(username)
}
if len(api.OIDCConfig.EmailDomain) > 0 {
ok = false
for _, domain := range api.OIDCConfig.EmailDomain {
@ -374,6 +398,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
return
}
}
var picture string
pictureRaw, ok := claims["picture"]
if ok {
@ -388,6 +413,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
Email: email,
Username: username,
AvatarURL: picture,
Groups: groups,
})
var httpErr httpError
if xerrors.As(err, &httpErr) {
@ -425,6 +451,7 @@ type oauthLoginParams struct {
Email string
Username string
AvatarURL string
Groups []string
}
type httpError struct {
@ -546,22 +573,6 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
}
}
// LEGACY: Remove 10/2022.
// We started tracking linked IDs later so it's possible for a user to be a
// pre-existing OAuth user and not have a linked ID.
// The migration that added the user_links table could not populate
// the 'linked_id' field since it requires fields off the access token.
if link.LinkedID == "" {
link, err = tx.UpdateUserLinkedID(ctx, database.UpdateUserLinkedIDParams{
UserID: user.ID,
LoginType: params.LoginType,
LinkedID: params.LinkedID,
})
if err != nil {
return xerrors.Errorf("update user linked ID: %w", err)
}
}
if link.UserID != uuid.Nil {
link, err = tx.UpdateUserLink(ctx, database.UpdateUserLinkParams{
UserID: user.ID,
@ -575,6 +586,14 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
}
}
// Ensure groups are correct.
if len(params.Groups) > 0 {
err := api.Options.SetUserGroups(ctx, tx, user.ID, params.Groups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
}
needsUpdate := false
if user.AvatarURL.String != params.AvatarURL {
user.AvatarURL = sql.NullString{

View File

@ -622,6 +622,14 @@ func TestUserOIDC(t *testing.T) {
AllowSignups: true,
AvatarURL: "/example.png",
StatusCode: http.StatusTemporaryRedirect,
}, {
Name: "GroupsDoesNothing",
IDTokenClaims: jwt.MapClaims{
"email": "coolin@coder.com",
"groups": []string{"pingpong"},
},
AllowSignups: true,
StatusCode: http.StatusTemporaryRedirect,
}} {
tc := tc
t.Run(tc.Name, func(t *testing.T) {