diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 5ca9d856f6..8e40e30af9 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -3551,12 +3551,12 @@ func (q *fakeQuerier) InsertGroupMember(_ context.Context, arg database.InsertGr return nil } -func (q *fakeQuerier) DeleteGroupMember(_ context.Context, userID uuid.UUID) error { +func (q *fakeQuerier) DeleteGroupMemberFromGroup(_ context.Context, arg database.DeleteGroupMemberFromGroupParams) error { q.mutex.Lock() defer q.mutex.Unlock() for i, member := range q.groupMembers { - if member.UserID == userID { + if member.UserID == arg.UserID && member.GroupID == arg.GroupID { q.groupMembers = append(q.groupMembers[:i], q.groupMembers[i+1:]...) } } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 616912acc0..7747371624 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -23,7 +23,7 @@ type sqlcQuerier interface { DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error - DeleteGroupMember(ctx context.Context, userID uuid.UUID) error + DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteOldAgentStats(ctx context.Context) error DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index e9d9793f89..714ff68e53 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -990,15 +990,21 @@ func (q *sqlQuerier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error { return err } -const deleteGroupMember = `-- name: DeleteGroupMember :exec +const deleteGroupMemberFromGroup = `-- name: DeleteGroupMemberFromGroup :exec DELETE FROM group_members WHERE - user_id = $1 + user_id = $1 AND + group_id = $2 ` -func (q *sqlQuerier) DeleteGroupMember(ctx context.Context, userID uuid.UUID) error { - _, err := q.db.ExecContext(ctx, deleteGroupMember, userID) +type DeleteGroupMemberFromGroupParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + GroupID uuid.UUID `db:"group_id" json:"group_id"` +} + +func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error { + _, err := q.db.ExecContext(ctx, deleteGroupMemberFromGroup, arg.UserID, arg.GroupID) return err } @@ -1220,7 +1226,7 @@ INSERT INTO group_members ( user_id, group_id ) -VALUES ( $1, $2) +VALUES ($1, $2) ` type InsertGroupMemberParams struct { diff --git a/coderd/database/queries/groups.sql b/coderd/database/queries/groups.sql index 19fc7e9478..9571a652ed 100644 --- a/coderd/database/queries/groups.sql +++ b/coderd/database/queries/groups.sql @@ -85,13 +85,14 @@ INSERT INTO group_members ( user_id, group_id ) -VALUES ( $1, $2); +VALUES ($1, $2); --- name: DeleteGroupMember :exec +-- name: DeleteGroupMemberFromGroup :exec DELETE FROM group_members WHERE - user_id = $1; + user_id = $1 AND + group_id = $2; -- name: DeleteGroupByID :exec DELETE FROM diff --git a/enterprise/coderd/groups.go b/enterprise/coderd/groups.go index 7ca5b8e225..b75f29211a 100644 --- a/enterprise/coderd/groups.go +++ b/enterprise/coderd/groups.go @@ -207,16 +207,27 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) { } for _, id := range req.AddUsers { - err := tx.InsertGroupMember(ctx, database.InsertGroupMemberParams{ + userID, err := uuid.Parse(id) + if err != nil { + return xerrors.Errorf("parse user ID %q: %w", id, err) + } + err = tx.InsertGroupMember(ctx, database.InsertGroupMemberParams{ GroupID: group.ID, - UserID: uuid.MustParse(id), + UserID: userID, }) if err != nil { return xerrors.Errorf("insert group member %q: %w", id, err) } } for _, id := range req.RemoveUsers { - err := tx.DeleteGroupMember(ctx, uuid.MustParse(id)) + userID, err := uuid.Parse(id) + if err != nil { + return xerrors.Errorf("parse user ID %q: %w", id, err) + } + err = tx.DeleteGroupMemberFromGroup(ctx, database.DeleteGroupMemberFromGroupParams{ + UserID: userID, + GroupID: group.ID, + }) if err != nil { return xerrors.Errorf("insert group member %q: %w", id, err) }