mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
fix: use unique ID for linked accounts (#3441)
- move OAuth-related fields off of api_keys into a new user_links table - restrict users to single form of login - process updates to user email/usernames for OIDC - added a login_type column to users
This commit is contained in:
@ -73,6 +73,7 @@ type data struct {
|
||||
organizations []database.Organization
|
||||
organizationMembers []database.OrganizationMember
|
||||
users []database.User
|
||||
userLinks []database.UserLink
|
||||
|
||||
// New tables
|
||||
auditLogs []database.AuditLog
|
||||
@ -1454,20 +1455,16 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
|
||||
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
ID: arg.ID,
|
||||
LifetimeSeconds: arg.LifetimeSeconds,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
IPAddress: arg.IPAddress,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthIDToken: arg.OAuthIDToken,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
ID: arg.ID,
|
||||
LifetimeSeconds: arg.LifetimeSeconds,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
IPAddress: arg.IPAddress,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
}
|
||||
q.apiKeys = append(q.apiKeys, key)
|
||||
return key, nil
|
||||
@ -1744,6 +1741,7 @@ func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
|
||||
Username: arg.Username,
|
||||
Status: database.UserStatusActive,
|
||||
RBACRoles: arg.RBACRoles,
|
||||
LoginType: arg.LoginType,
|
||||
}
|
||||
q.users = append(q.users, user)
|
||||
return user, nil
|
||||
@ -1899,9 +1897,6 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
|
||||
apiKey.LastUsed = arg.LastUsed
|
||||
apiKey.ExpiresAt = arg.ExpiresAt
|
||||
apiKey.IPAddress = arg.IPAddress
|
||||
apiKey.OAuthAccessToken = arg.OAuthAccessToken
|
||||
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
apiKey.OAuthExpiry = arg.OAuthExpiry
|
||||
q.apiKeys[index] = apiKey
|
||||
return nil
|
||||
}
|
||||
@ -2260,3 +2255,80 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
|
||||
|
||||
return q.deploymentID, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserLinkByLinkedID(_ context.Context, id string) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, link := range q.userLinks {
|
||||
if link.LinkedID == id {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params database.GetUserLinkByUserIDLoginTypeParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertUserLink(_ context.Context, args database.InsertUserLinkParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
//nolint:gosimple
|
||||
link := database.UserLink{
|
||||
UserID: args.UserID,
|
||||
LoginType: args.LoginType,
|
||||
LinkedID: args.LinkedID,
|
||||
OAuthAccessToken: args.OAuthAccessToken,
|
||||
OAuthRefreshToken: args.OAuthRefreshToken,
|
||||
OAuthExpiry: args.OAuthExpiry,
|
||||
}
|
||||
|
||||
q.userLinks = append(q.userLinks, link)
|
||||
|
||||
return link, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateUserLinkedID(_ context.Context, params database.UpdateUserLinkedIDParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
link.LinkedID = params.LinkedID
|
||||
|
||||
q.userLinks[i] = link
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUserLinkParams) (database.UserLink, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for i, link := range q.userLinks {
|
||||
if link.UserID == params.UserID && link.LoginType == params.LoginType {
|
||||
link.OAuthAccessToken = params.OAuthAccessToken
|
||||
link.OAuthRefreshToken = params.OAuthRefreshToken
|
||||
link.OAuthExpiry = params.OAuthExpiry
|
||||
|
||||
q.userLinks[i] = link
|
||||
return link, nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.UserLink{}, sql.ErrNoRows
|
||||
}
|
||||
|
Reference in New Issue
Block a user