fix: use unique ID for linked accounts (#3441)

- move OAuth-related fields off of api_keys into a new user_links table
- restrict users to single form of login
- process updates to user email/usernames for OIDC
- added a login_type column to users
This commit is contained in:
Jon Ayers
2022-08-17 18:00:53 -05:00
committed by GitHub
parent 53d1fb36db
commit c3eea98db0
29 changed files with 931 additions and 266 deletions

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/oauth2"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
"github.com/coder/coder/coderd/database"
@ -149,9 +150,21 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
// Tracks if the API key has properties updated!
changed := false
var link database.UserLink
if key.LoginType != database.LoginTypePassword {
link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
Message: "A database error occurred",
Detail: fmt.Sprintf("get user link by user ID and login type: %s", err.Error()),
})
return
}
// Check if the OAuth token is expired!
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() {
var oauthConfig OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
@ -167,9 +180,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
}
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OAuthAccessToken,
RefreshToken: key.OAuthRefreshToken,
Expiry: key.OAuthExpiry,
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
Expiry: link.OAuthExpiry,
}).Token()
if err != nil {
write(http.StatusUnauthorized, codersdk.Response{
@ -178,9 +191,9 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
})
return
}
key.OAuthAccessToken = token.AccessToken
key.OAuthRefreshToken = token.RefreshToken
key.OAuthExpiry = token.Expiry
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
@ -222,13 +235,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
OAuthAccessToken: key.OAuthAccessToken,
OAuthRefreshToken: key.OAuthRefreshToken,
OAuthExpiry: key.OAuthExpiry,
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
IPAddress: key.IPAddress,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
@ -237,6 +247,24 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
})
return
}
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: link.OAuthExpiry,
})
if err != nil {
write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
return
}
}
}
// If the key is valid, we also fetch the user roles and status.

View File

@ -187,6 +187,7 @@ func TestAPIKey(t *testing.T) {
ID: id,
HashedSecret: hashed[:],
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -215,6 +216,7 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -253,6 +255,7 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -288,6 +291,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -323,6 +327,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
@ -361,6 +366,13 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
@ -393,10 +405,16 @@ func TestAPIKey(t *testing.T) {
HashedSecret: hashed[:],
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
OAuthExpiry: database.Now().AddDate(0, 0, -1),
UserID: user.ID,
})
require.NoError(t, err)
_, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
@ -418,7 +436,6 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
})
t.Run("RemoteIPUpdates", func(t *testing.T) {
@ -443,6 +460,7 @@ func TestAPIKey(t *testing.T) {
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)

View File

@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestOrganizationParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))

View File

@ -53,6 +53,7 @@ func TestTemplateParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestTemplateVersionParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -47,6 +47,7 @@ func TestUserParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)

View File

@ -53,6 +53,7 @@ func TestWorkspaceParam(t *testing.T) {
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
LoginType: database.LoginTypePassword,
})
require.NoError(t, err)