fix: remove refresh oauth logic on OIDC login (#8950)

* fix: do not do oauth refresh logic on oidc login
This commit is contained in:
Steven Masley
2023-08-08 10:05:12 -05:00
committed by GitHub
parent 1d4a72f43f
commit 5339a31532
6 changed files with 217 additions and 68 deletions

View File

@ -693,7 +693,6 @@ func New(options *Options) *API {
r.Route("/github", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil),
apiKeyMiddlewareOptional,
)
r.Get("/callback", api.userOAuth2Github)
})
@ -701,7 +700,6 @@ func New(options *Options) *API {
r.Route("/oidc/callback", func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams),
apiKeyMiddlewareOptional,
)
r.Get("/", api.userOIDC)
})

View File

@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
// These are optional
refreshToken string
oidcTokenExpires func() time.Time
tokenSource func() (*oauth2.Token, error)
}
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.refreshToken = token
}
}
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.oidcTokenExpires = expFunc
}
}
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.tokenSource = src
}
}
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
t.Helper()
block, _ := pem.Decode([]byte(testRSAPrivateKey))
@ -1035,33 +1057,58 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
issuer = "https://coder.com"
}
return &OIDCConfig{
cfg := &OIDCConfig{
key: pkey,
issuer: issuer,
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
type tokenSource struct {
src func() (*oauth2.Token, error)
}
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
func (s tokenSource) Token() (*oauth2.Token, error) {
return s.src()
}
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
if cfg.tokenSource == nil {
return nil
}
return tokenSource{
src: cfg.tokenSource,
}
}
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}
var exp time.Time
if cfg.oidcTokenExpires != nil {
exp = cfg.oidcTokenExpires()
}
return (&oauth2.Token{
AccessToken: "token",
AccessToken: "token",
RefreshToken: cfg.refreshToken,
Expiry: exp,
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
}
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
t.Helper()
if _, ok := claims["exp"]; !ok {
@ -1069,20 +1116,20 @@ func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = o.issuer
claims["iss"] = cfg.issuer
}
if _, ok := claims["sub"]; !ok {
claims["sub"] = "testme"
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString([]byte(signed))
}
func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
// By default, the provider can be empty.
// This means it won't support any endpoints!
provider := &oidc.Provider{}
@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
}
provider = cfg.NewProvider(context.Background())
}
cfg := &coderd.OIDCConfig{
OAuth2Config: o,
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{o.key.Public()},
newCFG := &coderd.OIDCConfig{
OAuth2Config: cfg,
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
}),
@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
GroupField: "groups",
}
for _, opt := range opts {
opt(cfg)
opt(newCFG)
}
return cfg
return newCFG
}
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking

View File

@ -142,6 +142,56 @@ func ExtractAPIKeyMW(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}
func APIKeyFromRequest(ctx context.Context, db database.Store, sessionTokenFunc func(r *http.Request) string, r *http.Request) (*database.APIKey, codersdk.Response, bool) {
tokenFunc := APITokenFromRequest
if sessionTokenFunc != nil {
tokenFunc = sessionTokenFunc
}
token := tokenFunc(r)
if token == "" {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
}, false
}
keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
}, false
}
//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
}, false
}
return nil, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
}, false
}
// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return nil, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
}, false
}
return &key, codersdk.Response{}, true
}
// ExtractAPIKey requires authentication using a valid API key. It handles
// extending an API key if it comes close to expiry, updating the last used time
// in the database.
@ -179,49 +229,9 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return nil, nil, false
}
tokenFunc := APITokenFromRequest
if cfg.SessionTokenFunc != nil {
tokenFunc = cfg.SessionTokenFunc
}
token := tokenFunc(r)
if token == "" {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenCookie),
})
}
keyID, keySecret, err := SplitAPIToken(token)
if err != nil {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "Invalid API key format: " + err.Error(),
})
}
//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key is invalid.",
})
}
return write(http.StatusInternalServerError, codersdk.Response{
Message: internalErrorMessage,
Detail: fmt.Sprintf("Internal error fetching API key by id. %s", err.Error()),
})
}
// Checking to see if the secret is valid.
hashedSecret := sha256.Sum256([]byte(keySecret))
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
return optionalWrite(http.StatusUnauthorized, codersdk.Response{
Message: SignedOutErrorMessage,
Detail: "API key secret is invalid.",
})
key, resp, ok := APIKeyFromRequest(ctx, cfg.DB, cfg.SessionTokenFunc, r)
if !ok {
return optionalWrite(http.StatusUnauthorized, resp)
}
var (
@ -232,7 +242,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
)
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
link, err := cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
@ -427,7 +437,7 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
}.WithCachedASTValue(),
}
return &key, &authz, true
return key, &authz, true
}
// APITokenFromRequest returns the api token from the request.

View File

@ -1427,7 +1427,8 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
var key database.APIKey
if oldKey, ok := httpmw.APIKeyOptional(r); ok && isConvertLoginType {
oldKey, _, ok := httpmw.APIKeyFromRequest(ctx, api.Database, nil, r)
if ok && oldKey != nil && isConvertLoginType {
// If this is a convert login type, and it succeeds, then delete the old
// session. Force the user to log back in.
err := api.Database.DeleteAPIKeyByID(r.Context(), oldKey.ID)
@ -1447,7 +1448,9 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
Secure: api.SecureAuthCookie,
HttpOnly: true,
})
key = oldKey
// This is intentional setting the key to the deleted old key,
// as the user needs to be forced to log back in.
key = *oldKey
} else {
//nolint:gocritic
cookie, newKey, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{

View File

@ -9,6 +9,7 @@ import (
"net/http/cookiejar"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
@ -24,12 +25,97 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
// This test specifically tests logging in with OIDC when an expired
// OIDC session token exists.
// The token refreshing should not happen since we are reauthenticating.
func TestOIDCOauthLoginWithExisting(t *testing.T) {
t.Parallel()
conf := coderdtest.NewOIDCConfig(t, "",
// Provide a refresh token so we use the refresh token flow
coderdtest.WithRefreshToken("refresh_token"),
// We need to set the expire in the future for the first api calls.
coderdtest.WithTokenExpires(func() time.Time {
return time.Now().Add(time.Hour).UTC()
}),
// No refresh should actually happen in this test.
coderdtest.WithTokenSource(func() (*oauth2.Token, error) {
return nil, xerrors.New("token should not require refresh")
}),
)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
auditor := audit.NewMock()
const username = "alice"
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"preferred_username": username,
}
config := conf.OIDCConfig(t, claims)
config.AllowSignups = true
config.IgnoreUserInfo = true
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: config,
Logger: &logger,
})
// Signup alice
resp := oidcCallback(t, client, conf.EncodeClaims(t, claims))
// Set the client to use this OIDC context
authCookie := authCookieValue(resp.Cookies())
client.SetSessionToken(authCookie)
_ = resp.Body.Close()
ctx := testutil.Context(t, testutil.WaitLong)
// Verify the user and oauth link
user, err := client.User(ctx, "me")
require.NoError(t, err)
require.Equal(t, username, user.Username)
// nolint:gocritic
link, err := api.Database.GetUserLinkByUserIDLoginType(dbauthz.AsSystemRestricted(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: user.ID,
LoginType: database.LoginType(user.LoginType),
})
require.NoError(t, err, "failed to get user link")
// Expire the link
// nolint:gocritic
_, err = api.Database.UpdateUserLink(dbauthz.AsSystemRestricted(ctx), database.UpdateUserLinkParams{
OAuthAccessToken: link.OAuthAccessToken,
OAuthRefreshToken: link.OAuthRefreshToken,
OAuthExpiry: time.Now().Add(time.Hour * -1).UTC(),
UserID: link.UserID,
LoginType: link.LoginType,
})
require.NoError(t, err, "failed to update user link")
// Log in again with OIDC
loginAgain := oidcCallbackWithState(t, client, conf.EncodeClaims(t, claims), "seconds_login", func(req *http.Request) {
req.AddCookie(&http.Cookie{
Name: codersdk.SessionTokenCookie,
Value: authCookie,
Path: "/",
})
})
require.Equal(t, http.StatusTemporaryRedirect, loginAgain.StatusCode)
_ = loginAgain.Body.Close()
// Try to use new login
client.SetSessionToken(authCookieValue(resp.Cookies()))
_, err = client.User(ctx, "me")
require.NoError(t, err, "use new session")
}
func TestUserLogin(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
@ -819,7 +905,7 @@ func TestUserOIDC(t *testing.T) {
})
require.NoError(t, err)
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString)
resp := oidcCallbackWithState(t, user, code, convertResponse.StateString, nil)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
@ -1045,10 +1131,10 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
}
func oidcCallback(t *testing.T, client *codersdk.Client, code string) *http.Response {
return oidcCallbackWithState(t, client, code, "somestate")
return oidcCallbackWithState(t, client, code, "somestate", nil)
}
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string) *http.Response {
func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state string, modify func(r *http.Request)) *http.Response {
t.Helper()
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
@ -1062,6 +1148,9 @@ func oidcCallbackWithState(t *testing.T, client *codersdk.Client, code, state st
Name: codersdk.OAuth2StateCookie,
Value: state,
})
if modify != nil {
modify(req)
}
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()