mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
fix: remove refresh oauth logic on OIDC login (#8950)
* fix: do not do oauth refresh logic on oidc login
This commit is contained in:
@ -693,7 +693,6 @@ func New(options *Options) *API {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
|
||||
apiKeyMiddlewareOptional,
|
||||
)
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
@ -701,7 +700,6 @@ func New(options *Options) *API {
|
||||
r.Route("/oidc/callback", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
|
||||
apiKeyMiddlewareOptional,
|
||||
)
|
||||
r.Get("/", api.userOIDC)
|
||||
})
|
||||
|
@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
// These are optional
|
||||
refreshToken string
|
||||
oidcTokenExpires func() time.Time
|
||||
tokenSource func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
|
||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.refreshToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.oidcTokenExpires = expFunc
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.tokenSource = src
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
@ -1035,33 +1057,58 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
return &OIDCConfig{
|
||||
cfg := &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
type tokenSource struct {
|
||||
src func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.src()
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
if cfg.tokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
return tokenSource{
|
||||
src: cfg.tokenSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
|
||||
var exp time.Time
|
||||
if cfg.oidcTokenExpires != nil {
|
||||
exp = cfg.oidcTokenExpires()
|
||||
}
|
||||
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
AccessToken: "token",
|
||||
RefreshToken: cfg.refreshToken,
|
||||
Expiry: exp,
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
@ -1069,20 +1116,20 @@ func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = o.issuer
|
||||
claims["iss"] = cfg.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
// By default, the provider can be empty.
|
||||
// This means it won't support any endpoints!
|
||||
provider := &oidc.Provider{}
|
||||
@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
|
||||
}
|
||||
provider = cfg.NewProvider(context.Background())
|
||||
}
|
||||
cfg := &coderd.OIDCConfig{
|
||||
OAuth2Config: o,
|
||||
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{o.key.Public()},
|
||||
newCFG := &coderd.OIDCConfig{
|
||||
OAuth2Config: cfg,
|
||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
opt(newCFG)
|
||||
}
|
||||
return cfg
|
||||
return newCFG
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
|
@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
|
||||
tokenFunc := APITokenFromRequest
|
||||
if sessionTokenFunc != nil {
|
||||
tokenFunc = sessionTokenFunc
|
||||
}
|
||||
|
||||
token := tokenFunc(r)
|
||||
if token == "" {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
|
||||
}, false
|
||||
}
|
||||
|
||||
keyID, keySecret, err := SplitAPIToken(token)
|
||||
if err != nil {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "Invalid API key format: " + err.Error(),
|
||||
}, false
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to fetch API key to check if it's valid.
|
||||
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key is invalid.",
|
||||
}, false
|
||||
}
|
||||
|
||||
return nil, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
|
||||
}, false
|
||||
}
|
||||
|
||||
// Checking to see if the secret is valid.
|
||||
hashedSecret := sha256.Sum256([]byte(keySecret))
|
||||
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
|
||||
return nil, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key secret is invalid.",
|
||||
}, false
|
||||
}
|
||||
|
||||
return &key, codersdk.Response{}, true
|
||||
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key. It handles
|
||||
// extending an API key if it comes close to expiry, updating the last used time
|
||||
// in the database.
|
||||
@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tokenFunc := APITokenFromRequest
|
||||
if cfg.SessionTokenFunc != nil {
|
||||
tokenFunc = cfg.SessionTokenFunc
|
||||
}
|
||||
token := tokenFunc(r)
|
||||
if token == "" {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
|
||||
})
|
||||
}
|
||||
|
||||
keyID, keySecret, err := SplitAPIToken(token)
|
||||
if err != nil {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "Invalid API key format: " + err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to fetch API key to check if it's valid.
|
||||
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key is invalid.",
|
||||
})
|
||||
}
|
||||
|
||||
return write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
// Checking to see if the secret is valid.
|
||||
hashedSecret := sha256.Sum256([]byte(keySecret))
|
||||
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
|
||||
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: SignedOutErrorMessage,
|
||||
Detail: "API key secret is invalid.",
|
||||
})
|
||||
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
|
||||
if !ok {
|
||||
return optionalWrite(http.StatusUnauthorized, resp)
|
||||
}
|
||||
|
||||
var (
|
||||
@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
)
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
}.WithCachedASTValue(),
|
||||
}
|
||||
|
||||
return &key, &authz, true
|
||||
return key, &authz, true
|
||||
}
|
||||
|
||||
// APITokenFromRequest returns the api token from the request.
|
||||
|
@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
||||
}
|
||||
|
||||
var key database.APIKey
|
||||
if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType {
|
||||
oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r)
|
||||
if ok && oldKey != nil && isConvertLoginType {
|
||||
// If this is a convert login type, and it succeeds, then delete the old
|
||||
// session. Force the user to log back in.
|
||||
err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID)
|
||||
@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
||||
Secure: api.SecureAuthCookie,
|
||||
HttpOnly: true,
|
||||
})
|
||||
key = oldKey
|
||||
// This is intentional setting the key to the deleted old key,
|
||||
// as the user needs to be forced to log back in.
|
||||
key = *oldKey
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"net/http/cookiejar"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
@ -24,12 +25,97 @@ import (
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
// This test specifically tests logging in with OIDC when an expired
|
||||
// OIDC session token exists.
|
||||
// The token refreshing should not happen since we are reauthenticating.
|
||||
func TestOIDCOauthLoginWithExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := coderdtest.NewOIDCConfig(t, "",
|
||||
// Provide a refresh token so we use the refresh token flow
|
||||
coderdtest.WithRefreshToken("refresh_token"),
|
||||
// We need to set the expire in the future for the first api calls.
|
||||
coderdtest.WithTokenExpires(func() time.Time {
|
||||
return time.Now().Add(time.Hour).UTC()
|
||||
}),
|
||||
// No refresh should actually happen in this test.
|
||||
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
|
||||
return nil, xerrors.New("token should not require refresh")
|
||||
}),
|
||||
)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
auditor := audit.NewMock()
|
||||
const username = "alice"
|
||||
claims := jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"email_verified": true,
|
||||
"preferred_username": username,
|
||||
}
|
||||
config := conf.OIDCConfig(t, claims)
|
||||
|
||||
config.AllowSignups = true
|
||||
config.IgnoreUserInfo = true
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: config,
|
||||
Logger: &logger,
|
||||
})
|
||||
|
||||
// Signup alice
|
||||
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
|
||||
// Set the client to use this OIDC context
|
||||
authCookie := authCookieValue(resp.Cookies())
|
||||
client.SetSessionToken(authCookie)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
// Verify the user and oauth link
|
||||
user, err := client.User(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, username, user.Username)
|
||||
|
||||
// nolint:gocritic
|
||||
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginType(user.LoginType),
|
||||
})
|
||||
require.NoError(t, err, "failed to get user link")
|
||||
|
||||
// Expire the link
|
||||
// nolint:gocritic
|
||||
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err, "failed to update user link")
|
||||
|
||||
// Log in again with OIDC
|
||||
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: authCookie,
|
||||
Path: "/",
|
||||
})
|
||||
})
|
||||
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)
|
||||
_ = loginAgain.Body.Close()
|
||||
|
||||
// Try to use new login
|
||||
client.SetSessionToken(authCookieValue(resp.Cookies()))
|
||||
_, err = client.User(ctx, "me")
|
||||
require.NoError(t, err, "use new session")
|
||||
}
|
||||
|
||||
func TestUserLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
@ -819,7 +905,7 @@ func TestUserOIDC(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString)
|
||||
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
|
||||
@ -1045,10 +1131,10 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
}
|
||||
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
|
||||
return oidcCallbackWithState(t, client, code, "somestate")
|
||||
return oidcCallbackWithState(t, client, code, "somestate", nil)
|
||||
}
|
||||
|
||||
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string) *http.Response {
|
||||
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
@ -1062,6 +1148,9 @@ func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state st
|
||||
Name: codersdk.OAuth2StateCookie,
|
||||
Value: state,
|
||||
})
|
||||
if modify != nil {
|
||||
modify(req)
|
||||
}
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
Reference in New Issue
Block a user