feat: add ability for users to convert their password login type to oauth/github login (#8105)

* Currently toggled by experiment flag

---------

Co-authored-by: Bruno Quaresma <bruno@coder.com>
This commit is contained in:
Steven Masley
2023-06-30 08:38:48 -04:00
committed by GitHub
parent 357f3b38f7
commit b5f26d9bdf
50 changed files with 2043 additions and 261 deletions

View File

@ -1041,6 +1041,13 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}
func (q *querier) GetOAuthSigningKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return "", err
}
return q.db.GetOAuthSigningKey(ctx)
}
func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id)
}
@ -2351,6 +2358,13 @@ func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUse
return q.db.UpdateUserLinkedID(ctx, arg)
}
func (q *querier) UpdateUserLoginType(ctx context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return database.User{}, err
}
return q.db.UpdateUserLoginType(ctx, arg)
}
func (q *querier) UpdateUserProfile(ctx context.Context, arg database.UpdateUserProfileParams) (database.User, error) {
u, err := q.db.GetUserByID(ctx, arg.ID)
if err != nil {
@ -2594,6 +2608,13 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
return q.db.UpsertLogoURL(ctx, value)
}
func (q *querier) UpsertOAuthSigningKey(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.UpsertOAuthSigningKey(ctx, value)
}
func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceDeploymentValues); err != nil {
return err

View File

@ -139,7 +139,6 @@ type data struct {
workspaceResources []database.WorkspaceResource
workspaces []database.Workspace
workspaceProxies []database.WorkspaceProxy
// Locks is a map of lock names. Any keys within the map are currently
// locked.
locks map[int64]struct{}
@ -149,6 +148,7 @@ type data struct {
serviceBanner []byte
logoURL string
appSecurityKey string
oauthSigningKey string
lastLicenseID int32
defaultProxyDisplayName string
defaultProxyIconURL string
@ -1868,6 +1868,13 @@ func (q *fakeQuerier) GetLogoURL(_ context.Context) (string, error) {
return q.logoURL, nil
}
func (q *fakeQuerier) GetOAuthSigningKey(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
return q.oauthSigningKey, nil
}
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -4877,6 +4884,27 @@ func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.Upda
return database.UserLink{}, sql.ErrNoRows
}
func (q *fakeQuerier) UpdateUserLoginType(_ context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
if err := validateDatabaseType(arg); err != nil {
return database.User{}, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
for i, u := range q.users {
if u.ID == arg.UserID {
u.LoginType = arg.NewLoginType
if arg.NewLoginType != database.LoginTypePassword {
u.HashedPassword = []byte{}
}
q.users[i] = u
return u, nil
}
}
return database.User{}, sql.ErrNoRows
}
func (q *fakeQuerier) UpdateUserProfile(_ context.Context, arg database.UpdateUserProfileParams) (database.User, error) {
if err := validateDatabaseType(arg); err != nil {
return database.User{}, err
@ -5334,6 +5362,14 @@ func (q *fakeQuerier) UpsertLogoURL(_ context.Context, data string) error {
return nil
}
func (q *fakeQuerier) UpsertOAuthSigningKey(_ context.Context, value string) error {
q.mutex.Lock()
defer q.mutex.Unlock()
q.oauthSigningKey = value
return nil
}
func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error {
q.mutex.RLock()
defer q.mutex.RUnlock()

View File

@ -199,7 +199,7 @@ func User(t testing.TB, db database.Store, orig database.User) database.User {
ID: takeFirst(orig.ID, uuid.New()),
Email: takeFirst(orig.Email, namesgenerator.GetRandomName(1)),
Username: takeFirst(orig.Username, namesgenerator.GetRandomName(1)),
HashedPassword: takeFirstSlice(orig.HashedPassword, []byte{}),
HashedPassword: takeFirstSlice(orig.HashedPassword, []byte(must(cryptorand.String(32)))),
CreatedAt: takeFirst(orig.CreatedAt, database.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, database.Now()),
RBACRoles: takeFirstSlice(orig.RBACRoles, []string{}),

View File

@ -455,6 +455,13 @@ func (m metricsStore) GetLogoURL(ctx context.Context) (string, error) {
return url, err
}
func (m metricsStore) GetOAuthSigningKey(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetOAuthSigningKey(ctx)
m.queryLatencies.WithLabelValues("GetOAuthSigningKey").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
start := time.Now()
organization, err := m.s.GetOrganizationByID(ctx, id)
@ -1426,6 +1433,13 @@ func (m metricsStore) UpdateUserLinkedID(ctx context.Context, arg database.Updat
return link, err
}
func (m metricsStore) UpdateUserLoginType(ctx context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserLoginType(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserLoginType").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) UpdateUserProfile(ctx context.Context, arg database.UpdateUserProfileParams) (database.User, error) {
start := time.Now()
user, err := m.s.UpdateUserProfile(ctx, arg)
@ -1594,6 +1608,13 @@ func (m metricsStore) UpsertLogoURL(ctx context.Context, value string) error {
return r0
}
func (m metricsStore) UpsertOAuthSigningKey(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertOAuthSigningKey(ctx, value)
m.queryLatencies.WithLabelValues("UpsertOAuthSigningKey").Observe(time.Since(start).Seconds())
return r0
}
func (m metricsStore) UpsertServiceBanner(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertServiceBanner(ctx, value)

View File

@ -821,6 +821,21 @@ func (mr *MockStoreMockRecorder) GetLogoURL(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), arg0)
}
// GetOAuthSigningKey mocks base method.
func (m *MockStore) GetOAuthSigningKey(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOAuthSigningKey", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOAuthSigningKey indicates an expected call of GetOAuthSigningKey.
func (mr *MockStoreMockRecorder) GetOAuthSigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuthSigningKey", reflect.TypeOf((*MockStore)(nil).GetOAuthSigningKey), arg0)
}
// GetOrganizationByID mocks base method.
func (m *MockStore) GetOrganizationByID(arg0 context.Context, arg1 uuid.UUID) (database.Organization, error) {
m.ctrl.T.Helper()
@ -2949,6 +2964,21 @@ func (mr *MockStoreMockRecorder) UpdateUserLinkedID(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserLinkedID", reflect.TypeOf((*MockStore)(nil).UpdateUserLinkedID), arg0, arg1)
}
// UpdateUserLoginType mocks base method.
func (m *MockStore) UpdateUserLoginType(arg0 context.Context, arg1 database.UpdateUserLoginTypeParams) (database.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserLoginType", arg0, arg1)
ret0, _ := ret[0].(database.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserLoginType indicates an expected call of UpdateUserLoginType.
func (mr *MockStoreMockRecorder) UpdateUserLoginType(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserLoginType", reflect.TypeOf((*MockStore)(nil).UpdateUserLoginType), arg0, arg1)
}
// UpdateUserProfile mocks base method.
func (m *MockStore) UpdateUserProfile(arg0 context.Context, arg1 database.UpdateUserProfileParams) (database.User, error) {
m.ctrl.T.Helper()
@ -3292,6 +3322,20 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(arg0, arg1 interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), arg0, arg1)
}
// UpsertOAuthSigningKey mocks base method.
func (m *MockStore) UpsertOAuthSigningKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertOAuthSigningKey", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertOAuthSigningKey indicates an expected call of UpsertOAuthSigningKey.
func (mr *MockStoreMockRecorder) UpsertOAuthSigningKey(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertOAuthSigningKey", reflect.TypeOf((*MockStore)(nil).UpsertOAuthSigningKey), arg0, arg1)
}
// UpsertServiceBanner mocks base method.
func (m *MockStore) UpsertServiceBanner(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()

View File

@ -99,7 +99,8 @@ CREATE TYPE resource_type AS ENUM (
'group',
'workspace_build',
'license',
'workspace_proxy'
'workspace_proxy',
'convert_login'
);
CREATE TYPE startup_script_behavior AS ENUM (

View File

@ -0,0 +1 @@
-- Nothing to do

View File

@ -0,0 +1,2 @@
-- This has to be outside a transaction
ALTER TYPE resource_type ADD VALUE IF NOT EXISTS 'convert_login';

View File

@ -891,6 +891,7 @@ const (
ResourceTypeWorkspaceBuild ResourceType = "workspace_build"
ResourceTypeLicense ResourceType = "license"
ResourceTypeWorkspaceProxy ResourceType = "workspace_proxy"
ResourceTypeConvertLogin ResourceType = "convert_login"
)
func (e *ResourceType) Scan(src interface{}) error {
@ -940,7 +941,8 @@ func (e ResourceType) Valid() bool {
ResourceTypeGroup,
ResourceTypeWorkspaceBuild,
ResourceTypeLicense,
ResourceTypeWorkspaceProxy:
ResourceTypeWorkspaceProxy,
ResourceTypeConvertLogin:
return true
}
return false
@ -959,6 +961,7 @@ func AllResourceTypeValues() []ResourceType {
ResourceTypeWorkspaceBuild,
ResourceTypeLicense,
ResourceTypeWorkspaceProxy,
ResourceTypeConvertLogin,
}
}

View File

@ -81,6 +81,7 @@ type sqlcQuerier interface {
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
GetOAuthSigningKey(ctx context.Context) (string, error)
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error)
@ -239,6 +240,7 @@ type sqlcQuerier interface {
UpdateUserLastSeenAt(ctx context.Context, arg UpdateUserLastSeenAtParams) (User, error)
UpdateUserLink(ctx context.Context, arg UpdateUserLinkParams) (UserLink, error)
UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinkedIDParams) (UserLink, error)
UpdateUserLoginType(ctx context.Context, arg UpdateUserLoginTypeParams) (User, error)
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
@ -267,6 +269,7 @@ type sqlcQuerier interface {
UpsertDefaultProxy(ctx context.Context, arg UpsertDefaultProxyParams) error
UpsertLastUpdateCheck(ctx context.Context, value string) error
UpsertLogoURL(ctx context.Context, value string) error
UpsertOAuthSigningKey(ctx context.Context, value string) error
UpsertServiceBanner(ctx context.Context, value string) error
UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error)
UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error)

View File

@ -441,6 +441,53 @@ func TestUserLastSeenFilter(t *testing.T) {
})
}
func TestUserChangeLoginType(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
ctx := context.Background()
alice := dbgen.User(t, db, database.User{
LoginType: database.LoginTypePassword,
})
bob := dbgen.User(t, db, database.User{
LoginType: database.LoginTypePassword,
})
bobExpPass := bob.HashedPassword
require.NotEmpty(t, alice.HashedPassword, "hashed password should not start empty")
require.NotEmpty(t, bob.HashedPassword, "hashed password should not start empty")
alice, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
NewLoginType: database.LoginTypeOIDC,
UserID: alice.ID,
})
require.NoError(t, err)
require.Empty(t, alice.HashedPassword, "hashed password should be empty")
// First check other users are not affected
bob, err = db.GetUserByID(ctx, bob.ID)
require.NoError(t, err)
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
// Then check password -> password is a noop
bob, err = db.UpdateUserLoginType(ctx, database.UpdateUserLoginTypeParams{
NewLoginType: database.LoginTypePassword,
UserID: bob.ID,
})
require.NoError(t, err)
bob, err = db.GetUserByID(ctx, bob.ID)
require.NoError(t, err)
require.Equal(t, bobExpPass, bob.HashedPassword, "hashed password should not change")
}
func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) {
t.Helper()
require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg)

View File

@ -3218,6 +3218,17 @@ func (q *sqlQuerier) GetLogoURL(ctx context.Context) (string, error) {
return value, err
}
const getOAuthSigningKey = `-- name: GetOAuthSigningKey :one
SELECT value FROM site_configs WHERE key = 'oauth_signing_key'
`
func (q *sqlQuerier) GetOAuthSigningKey(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getOAuthSigningKey)
var value string
err := row.Scan(&value)
return value, err
}
const getServiceBanner = `-- name: GetServiceBanner :one
SELECT value FROM site_configs WHERE key = 'service_banner'
`
@ -3300,6 +3311,16 @@ func (q *sqlQuerier) UpsertLogoURL(ctx context.Context, value string) error {
return err
}
const upsertOAuthSigningKey = `-- name: UpsertOAuthSigningKey :exec
INSERT INTO site_configs (key, value) VALUES ('oauth_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'oauth_signing_key'
`
func (q *sqlQuerier) UpsertOAuthSigningKey(ctx context.Context, value string) error {
_, err := q.db.ExecContext(ctx, upsertOAuthSigningKey, value)
return err
}
const upsertServiceBanner = `-- name: UpsertServiceBanner :exec
INSERT INTO site_configs (key, value) VALUES ('service_banner', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'service_banner'
@ -4999,7 +5020,7 @@ SELECT
FROM
users
WHERE
status = 'active'::user_status AND deleted = false
status = 'active'::user_status AND deleted = false
`
func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) {
@ -5251,9 +5272,9 @@ WHERE
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
-- everyone is a member.
WHEN cardinality($4 :: text[]) > 0 AND 'member' != ANY($4 :: text[]) THEN
rbac_roles && $4 :: text[]
rbac_roles && $4 :: text[]
ELSE true
END
-- Filter by last_seen
@ -5523,6 +5544,47 @@ func (q *sqlQuerier) UpdateUserLastSeenAt(ctx context.Context, arg UpdateUserLas
return i, err
}
const updateUserLoginType = `-- name: UpdateUserLoginType :one
UPDATE
users
SET
login_type = $1,
hashed_password = CASE WHEN $1 = 'password' :: login_type THEN
users.hashed_password
ELSE
-- If the login type is not password, then the password should be
-- cleared.
'':: bytea
END
WHERE
id = $2 RETURNING id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
`
type UpdateUserLoginTypeParams struct {
NewLoginType LoginType `db:"new_login_type" json:"new_login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) UpdateUserLoginType(ctx context.Context, arg UpdateUserLoginTypeParams) (User, error) {
row := q.db.QueryRowContext(ctx, updateUserLoginType, arg.NewLoginType, arg.UserID)
var i User
err := row.Scan(
&i.ID,
&i.Email,
&i.Username,
&i.HashedPassword,
&i.CreatedAt,
&i.UpdatedAt,
&i.Status,
&i.RBACRoles,
&i.LoginType,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
)
return i, err
}
const updateUserProfile = `-- name: UpdateUserProfile :one
UPDATE
users

View File

@ -17,7 +17,6 @@ SELECT
COALESCE((SELECT value FROM site_configs WHERE key = 'default_proxy_icon_url'), '/emojis/1f3e1.png') :: text AS icon_url
;
-- name: InsertDeploymentID :exec
INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1);
@ -57,3 +56,10 @@ SELECT value FROM site_configs WHERE key = 'app_signing_key';
-- name: UpsertAppSecurityKey :exec
INSERT INTO site_configs (key, value) VALUES ('app_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'app_signing_key';
-- name: GetOAuthSigningKey :one
SELECT value FROM site_configs WHERE key = 'oauth_signing_key';
-- name: UpsertOAuthSigningKey :exec
INSERT INTO site_configs (key, value) VALUES ('oauth_signing_key', $1)
ON CONFLICT (key) DO UPDATE set value = $1 WHERE site_configs.key = 'oauth_signing_key';

View File

@ -1,3 +1,18 @@
-- name: UpdateUserLoginType :one
UPDATE
users
SET
login_type = @new_login_type,
hashed_password = CASE WHEN @new_login_type = 'password' :: login_type THEN
users.hashed_password
ELSE
-- If the login type is not password, then the password should be
-- cleared.
'':: bytea
END
WHERE
id = @user_id RETURNING *;
-- name: GetUserByID :one
SELECT
*
@ -39,7 +54,7 @@ SELECT
FROM
users
WHERE
status = 'active'::user_status AND deleted = false;
status = 'active'::user_status AND deleted = false;
-- name: GetFilteredUserCount :one
-- This will never count deleted users.
@ -176,9 +191,9 @@ WHERE
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
-- everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
rbac_roles && @rbac_role :: text[]
rbac_roles && @rbac_role :: text[]
ELSE true
END
-- Filter by last_seen

View File

@ -5,11 +5,24 @@ import (
"encoding/json"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac"
)
// AuditOAuthConvertState is never stored in the database. It is stored in a cookie
// clientside as a JWT. This type is provided for audit logging purposes.
type AuditOAuthConvertState struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
// The time at which the state string expires, a merge request times out if the user does not perform it quick enough.
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
FromLoginType LoginType `db:"from_login_type" json:"from_login_type"`
// The login type the user is converting to. Should be github or oidc.
ToLoginType LoginType `db:"to_login_type" json:"to_login_type"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
type Actions []rbac.Action
func (a *Actions) Scan(src interface{}) error {