fix: stop extending API key access if OIDC refresh is available (#17878)

fixes #17070

Cleans up our handling of APIKey expiration and OIDC to keep them separate concepts. For an OIDC-login APIKey, both the APIKey and OIDC link must be valid to login. If the OIDC link is expired and we have a refresh token, we will attempt to refresh.

OIDC refreshes do not have any effect on APIKey expiry.

https://github.com/coder/coder/issues/17070#issuecomment-2886183613 explains why this is the correct behavior.
This commit is contained in:
Spike Curtis
2025-05-19 12:05:35 +04:00
committed by GitHub
parent ca5a78adbf
commit 1a41608035
4 changed files with 210 additions and 48 deletions

View File

@ -307,7 +307,7 @@ func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values
// WithLogging is optional, but will log some HTTP calls made to the IDP.
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
return func(f *FakeIDP) {
f.logger = slogtest.Make(t, options)
f.logger = slogtest.Make(t, options).Named("fakeidp")
}
}
@ -794,6 +794,7 @@ func (f *FakeIDP) newToken(t testing.TB, email string, expires time.Time) string
func (f *FakeIDP) newRefreshTokens(email string) string {
refreshToken := uuid.NewString()
f.refreshTokens.Store(refreshToken, email)
f.logger.Info(context.Background(), "new refresh token", slog.F("email", email), slog.F("token", refreshToken))
return refreshToken
}
@ -1003,6 +1004,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
return
}
f.logger.Info(r.Context(), "http idp call refresh_token", slog.F("token", refreshToken))
_, ok := f.refreshTokens.Load(refreshToken)
if !assert.True(t, ok, "invalid refresh_token") {
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
@ -1026,6 +1028,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
f.refreshTokensUsed.Store(refreshToken, true)
// Always invalidate the refresh token after it is used.
f.refreshTokens.Delete(refreshToken)
f.logger.Info(r.Context(), "refresh token invalidated", slog.F("token", refreshToken))
case "urn:ietf:params:oauth:grant-type:device_code":
// Device flow
var resp externalauth.ExchangeDeviceCodeResponse

View File

@ -232,16 +232,21 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return optionalWrite(http.StatusUnauthorized, resp)
}
var (
link database.UserLink
now = dbtime.Now()
// Tracks if the API key has properties updated
changed = false
)
now := dbtime.Now()
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
var err error
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
@ -258,7 +263,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// Check if the OAuth token is expired
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
if !link.OAuthExpiry.IsZero() && link.OAuthExpiry.Before(now) {
if cfg.OAuth2Configs.IsZero() {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
@ -267,12 +272,15 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
var friendlyName string
var oauthConfig promoauth.OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = cfg.OAuth2Configs.Github
friendlyName = "GitHub"
case database.LoginTypeOIDC:
oauthConfig = cfg.OAuth2Configs.OIDC
friendlyName = "OpenID Connect"
default:
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
@ -292,7 +300,13 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// If it is, let's refresh it from the provided config
if link.OAuthRefreshToken == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("%s session expired at %q. Try signing in again.", friendlyName, link.OAuthExpiry.String()),
})
}
// We have a refresh token, so let's try it
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: link.OAuthAccessToken,
RefreshToken: link.OAuthRefreshToken,
@ -300,28 +314,39 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}).Token()
if err != nil {
return write(http.StatusUnauthorized, codersdk.Response{
Message: "Could not refresh expired Oauth token. Try re-authenticating to resolve this issue.",
Detail: err.Error(),
Message: fmt.Sprintf(
"Could not refresh expired %s token. Try re-authenticating to resolve this issue.",
friendlyName),
Detail: err.Error(),
})
}
link.OAuthAccessToken = token.AccessToken
link.OAuthRefreshToken = token.RefreshToken
link.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
}
// Checking if the key is expired.
// NOTE: The `RequireAuth` React component depends on this `Detail` to detect when
// the users token has expired. If you change the text here, make sure to update it
// in site/src/components/RequireAuth/RequireAuth.tsx as well.
if key.ExpiresAt.Before(now) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
}
// Tracks if the API key has properties updated
changed := false
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
@ -363,29 +388,6 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
Detail: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
}
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
//nolint:gocritic // system needs to update user link
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: link.OAuthExpiry,
// Refresh should keep the same debug context because we use
// the original claims for the group/role sync.
Claims: link.Claims,
})
if err != nil {
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("update user_link: %s.", err.Error()),
})
}
}
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's

View File

@ -508,6 +508,102 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)
// Include a valid oauth token for refreshing. If this token is invalid,
// it is difficult to tell an auth failure from an expired api key, or
// an expired oauth key.
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now().AddDate(0, 0, -1),
ExpiresAt: dbtime.Now().AddDate(0, 0, -1),
LoginType: database.LoginTypeOIDC,
})
_ = dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID,
LoginType: database.LoginTypeOIDC,
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, token)
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
OIDC: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
@ -553,7 +649,67 @@ func TestAPIKey(t *testing.T) {
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, oauthToken.Expiry, gotAPIKey.ExpiresAt)
// Note that OAuth expiry is independent of APIKey expiry, so an OIDC refresh DOES NOT affect the expiry of the
// APIKey
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
gotLink, err := db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
})
require.NoError(t, err)
require.Equal(t, gotLink.OAuthRefreshToken, "moo")
})
t.Run("OAuthExpiredNoRefresh", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db = dbmem.New()
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
LastUsed: dbtime.Now(),
ExpiresAt: dbtime.Now().AddDate(0, 0, 1),
LoginType: database.LoginTypeGithub,
})
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
_, err := db.InsertUserLink(ctx, database.InsertUserLinkParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthExpiry: dbtime.Now().AddDate(0, 0, -1),
OAuthAccessToken: "letmein",
})
require.NoError(t, err)
r.Header.Set(codersdk.SessionTokenHeader, token)
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
Github: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("RemoteIPUpdates", func(t *testing.T) {

View File

@ -144,6 +144,7 @@ func TestAzureAKPKIWithCoderd(t *testing.T) {
return values, nil
}),
oidctest.WithServing(),
oidctest.WithLogging(t, nil),
)
cfg := fake.OIDCConfig(t, scopes, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true