mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
fix: use unique ID for linked accounts (#3441)
- move OAuth-related fields off of api_keys into a new user_links table - restrict users to single form of login - process updates to user email/usernames for OIDC - added a login_type column to users
This commit is contained in:
@ -14,6 +14,7 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
// Tracks if the API key has properties updated!
|
||||
changed := false
|
||||
|
||||
var link database.UserLink
|
||||
if key.LoginType != database.LoginTypePassword {
|
||||
link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "A database error occurred",
|
||||
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Check if the OAuth token is expired!
|
||||
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
|
||||
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() {
|
||||
var oauthConfig OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
}
|
||||
// If it is, let's refresh it from the provided config!
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: key.OAuthAccessToken,
|
||||
RefreshToken: key.OAuthRefreshToken,
|
||||
Expiry: key.OAuthExpiry,
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
Expiry: link.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
})
|
||||
return
|
||||
}
|
||||
key.OAuthAccessToken = token.AccessToken
|
||||
key.OAuthRefreshToken = token.RefreshToken
|
||||
key.OAuthExpiry = token.Expiry
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
key.ExpiresAt = token.Expiry
|
||||
changed = true
|
||||
}
|
||||
@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
}
|
||||
if changed {
|
||||
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
OAuthAccessToken: key.OAuthAccessToken,
|
||||
OAuthRefreshToken: key.OAuthRefreshToken,
|
||||
OAuthExpiry: key.OAuthExpiry,
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
IPAddress: key.IPAddress,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
@ -237,6 +247,24 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
})
|
||||
return
|
||||
}
|
||||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
|
@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) {
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) {
|
||||
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) {
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
@ -361,6 +366,13 @@ func TestAPIKey(t *testing.T) {
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
@ -393,10 +405,16 @@ func TestAPIKey(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "wow",
|
||||
RefreshToken: "moo",
|
||||
@ -418,7 +436,6 @@ func TestAPIKey(t *testing.T) {
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
|
||||
})
|
||||
|
||||
t.Run("RemoteIPUpdates", func(t *testing.T) {
|
||||
@ -443,6 +460,7 @@ func TestAPIKey(t *testing.T) {
|
||||
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
|
@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
|
@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) {
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
Reference in New Issue
Block a user