mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
fix: stop extending API key access if OIDC refresh is available (#17878)
fixes #17070 Cleans up our handling of APIKey expiration and OIDC to keep them separate concepts. For an OIDC-login APIKey, both the APIKey and OIDC link must be valid to login. If the OIDC link is expired and we have a refresh token, we will attempt to refresh. OIDC refreshes do not have any effect on APIKey expiry. https://github.com/coder/coder/issues/17070#issuecomment-2886183613 explains why this is the correct behavior.
This commit is contained in:
@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
|
||||
// WithLogging is optional, but will log some HTTP calls made to the IDP.
|
||||
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.logger = slogtest.Make(t, options)
|
||||
f.logger = slogtest.Make(t, options).Named("fakeidp")
|
||||
}
|
||||
}
|
||||
|
||||
@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
|
||||
func (f *FakeIDP) newRefreshTokens(email string) string {
|
||||
refreshToken := uuid.NewString()
|
||||
f.refreshTokens.Store(refreshToken, email)
|
||||
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
|
||||
_, ok := f.refreshTokens.Load(refreshToken)
|
||||
if !assert.True(t, ok, "invalid refresh_token") {
|
||||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||
@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
f.refreshTokensUsed.Store(refreshToken, true)
|
||||
// Always invalidate the refresh token after it is used.
|
||||
f.refreshTokens.Delete(refreshToken)
|
||||
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
|
||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||
// Device flow
|
||||
var resp externalauth.ExchangeDeviceCodeResponse
|
||||
|
@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return optionalWrite(http.StatusUnauthorized, resp)
|
||||
}
|
||||
|
||||
var (
|
||||
link database.UserLink
|
||||
now = dbtime.Now()
|
||||
// Tracks if the API key has properties updated
|
||||
changed = false
|
||||
)
|
||||
now := dbtime.Now()
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
}
|
||||
|
||||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
|
||||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
|
||||
// refreshing the OIDC token.
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
var err error
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
})
|
||||
}
|
||||
// Check if the OAuth token is expired
|
||||
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
|
||||
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
|
||||
if cfg.OAuth2Configs.IsZero() {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
})
|
||||
}
|
||||
|
||||
var friendlyName string
|
||||
var oauthConfig promoauth.OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
friendlyName = "GitHub"
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
friendlyName = "OpenID Connect"
|
||||
default:
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
@ -292,7 +300,13 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
})
|
||||
}
|
||||
|
||||
// If it is, let's refresh it from the provided config
|
||||
if link.OAuthRefreshToken == "" {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
|
||||
})
|
||||
}
|
||||
// We have a refresh token, so let's try it
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
@ -300,28 +314,39 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
}).Token()
|
||||
if err != nil {
|
||||
return write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
|
||||
Detail: err.Error(),
|
||||
Message: fmt.Sprintf(
|
||||
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
|
||||
friendlyName),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
link.OAuthAccessToken = token.AccessToken
|
||||
link.OAuthRefreshToken = token.RefreshToken
|
||||
link.OAuthExpiry = token.Expiry
|
||||
key.ExpiresAt = token.Expiry
|
||||
changed = true
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Checking if the key is expired.
|
||||
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
|
||||
// the users token has expired. If you change the text here, make sure to update it
|
||||
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
|
||||
if key.ExpiresAt.Before(now) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
}
|
||||
// Tracks if the API key has properties updated
|
||||
changed := false
|
||||
|
||||
// Only update LastUsed once an hour to prevent database spam.
|
||||
if now.Sub(key.LastUsed) > time.Hour {
|
||||
@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
//nolint:gocritic // system needs to update user link
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
// Refresh should keep the same debug context because we use
|
||||
// the original claims for the group/role sync.
|
||||
Claims: link.Claims,
|
||||
})
|
||||
if err != nil {
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
|
@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = dbmem.New()
|
||||
user = dbgen.User(t, db, database.User{})
|
||||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
LastUsed: dbtime.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
_ = dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token)
|
||||
|
||||
// Include a valid oauth token for refreshing. If this token is invalid,
|
||||
// it is difficult to tell an auth failure from an expired api key, or
|
||||
// an expired oauth key.
|
||||
oauthToken := &oauth2.Token{
|
||||
AccessToken: "wow",
|
||||
RefreshToken: "moo",
|
||||
Expiry: dbtime.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{
|
||||
OIDC: &testutil.OAuth2Config{
|
||||
Token: oauthToken,
|
||||
},
|
||||
},
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = dbmem.New()
|
||||
user = dbgen.User(t, db, database.User{})
|
||||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
LastUsed: dbtime.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
_ = dbgen.UserLink(t, db, database.UserLink{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token)
|
||||
|
||||
oauthToken := &oauth2.Token{
|
||||
AccessToken: "wow",
|
||||
RefreshToken: "moo",
|
||||
Expiry: dbtime.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{
|
||||
OIDC: &testutil.OAuth2Config{
|
||||
Token: oauthToken,
|
||||
},
|
||||
},
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("OAuthRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
|
||||
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
|
||||
// APIKey
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
|
||||
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
|
||||
})
|
||||
|
||||
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
db = dbmem.New()
|
||||
user = dbgen.User(t, db, database.User{})
|
||||
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
LastUsed: dbtime.Now(),
|
||||
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
|
||||
LoginType: database.LoginTypeGithub,
|
||||
})
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
|
||||
OAuthAccessToken: "letmein",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token)
|
||||
|
||||
oauthToken := &oauth2.Token{
|
||||
AccessToken: "wow",
|
||||
RefreshToken: "moo",
|
||||
Expiry: dbtime.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{
|
||||
Github: &testutil.OAuth2Config{
|
||||
Token: oauthToken,
|
||||
},
|
||||
},
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("RemoteIPUpdates", func(t *testing.T) {
|
||||
|
@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
|
||||
return values, nil
|
||||
}),
|
||||
oidctest.WithServing(),
|
||||
oidctest.WithLogging(t, nil),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
|
Reference in New Issue
Block a user