mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
fix: make 'NoRefresh' honor unlimited tokens in gitauth (#9472)
* chore: fix NoRefresh to honor unlimited tokens * improve testing coverage of gitauth * refactor rest of gitauth tests
This commit is contained in:
@ -3,18 +3,22 @@ package gitauth_test
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/gitauth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@ -22,17 +26,70 @@ import (
|
||||
|
||||
func TestRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("FalseIfNoRefresh", func(t *testing.T) {
|
||||
const providerID = "test-idp"
|
||||
expired := time.Now().Add(time.Hour * -1)
|
||||
|
||||
t.Run("NoRefreshExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
NoRefresh: true,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthExpiry: time.Time{},
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but NoRefresh was set")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
// The IDP should not be contacted since the token is expired. An expired
|
||||
// token with 'NoRefresh' should early abort.
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
t.Error("token was validated, but it was expired and this should never have happened.")
|
||||
return nil, xerrors.New("should not be called")
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Expire the link
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
|
||||
// NoRefreshNoExpiry tests that an oauth token without an expiry is always valid.
|
||||
// The "validate url" should be hit, but the refresh endpoint should not.
|
||||
t.Run("NoRefreshNoExpiry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but NoRefresh was set")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
// Zero time used
|
||||
link.OAuthExpiry = time.Time{}
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, refreshed, "token without expiry is always valid")
|
||||
require.True(t, validated, "token should have been validated")
|
||||
})
|
||||
|
||||
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
@ -42,111 +99,167 @@ func TestRefreshToken(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthExpiry: expired,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
|
||||
t.Run("ValidateServerError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Failure"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.ErrorContains(t, err, "Failure")
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, xerrors.New(staticError)
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, _, err := config.RefreshToken(ctx, nil, link)
|
||||
require.ErrorContains(t, err, staticError)
|
||||
require.True(t, validated, "token should have been attempted to be validated")
|
||||
})
|
||||
|
||||
// ValidateFailure tests if the token is no longer valid with a 401 response.
|
||||
t.Run("ValidateFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Not permitted"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.NoError(t, err)
|
||||
|
||||
const staticError = "static error"
|
||||
validated := false
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validated = true
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
_, refreshed, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err, staticError)
|
||||
require.False(t, refreshed)
|
||||
require.True(t, validated, "token should have been attempted to be validated")
|
||||
})
|
||||
|
||||
t.Run("ValidateRetryGitHub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hit := false
|
||||
// We need to ensure that the exponential backoff kicks in properly.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !hit {
|
||||
hit = true
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Not permitted"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
ID: "test",
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "updated",
|
||||
},
|
||||
|
||||
const staticError = "static error"
|
||||
validateCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but the token is not expired")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
// Make the first call return a 401, subsequent calls should return a 200.
|
||||
if validateCalls > 1 {
|
||||
return jwt.MapClaims{}, nil
|
||||
}
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
ValidateURL: srv.URL,
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
}
|
||||
db := dbfake.New()
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
ProviderID: config.ID,
|
||||
OAuthAccessToken: "initial",
|
||||
})
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), db, link)
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Unlimited lifetime, this is what GitHub returns tokens as
|
||||
link.OAuthExpiry = time.Time{}
|
||||
|
||||
_, ok, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, refreshed)
|
||||
require.True(t, hit)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 2, validateCalls, "token should have been attempted to be validated more than once")
|
||||
})
|
||||
|
||||
t.Run("ValidateNoUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
validated := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
close(validated)
|
||||
}))
|
||||
accessToken := "testing"
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
},
|
||||
|
||||
validateCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
t.Error("refresh on the IDP was called, but the token is not expired")
|
||||
return xerrors.New("should not be called")
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthAccessToken: accessToken,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
|
||||
_, ok, err := config.RefreshToken(ctx, nil, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
<-validated
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, validateCalls, "token is validated")
|
||||
})
|
||||
|
||||
// A token update comes from a refresh.
|
||||
t.Run("Updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
ID: "test",
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
db := dbfake.New()
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
ProviderID: config.ID,
|
||||
OAuthAccessToken: "initial",
|
||||
validateCalls := 0
|
||||
refreshCalls := 0
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithRefresh(func(_ string) error {
|
||||
refreshCalls++
|
||||
return nil
|
||||
}),
|
||||
oidctest.WithDynamicUserInfo(func(_ string) (jwt.MapClaims, error) {
|
||||
validateCalls++
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *gitauth.Config) {
|
||||
cfg.Type = codersdk.GitProviderGitHub
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
_, valid, err := config.RefreshToken(context.Background(), db, link)
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Force a refresh
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
updated, ok, err := config.RefreshToken(ctx, db, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, validateCalls, "token is validated")
|
||||
require.Equal(t, 1, refreshCalls, "token is refreshed")
|
||||
require.NotEqualf(t, link.OAuthAccessToken, updated.OAuthAccessToken, "token is updated")
|
||||
//nolint:gocritic // testing
|
||||
dbLink, err := db.GetGitAuthLink(dbauthz.AsSystemRestricted(context.Background()), database.GetGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
|
||||
})
|
||||
}
|
||||
|
||||
@ -232,3 +345,65 @@ func TestConvertYAML(t *testing.T) {
|
||||
require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL(""))
|
||||
})
|
||||
}
|
||||
|
||||
type testConfig struct {
|
||||
FakeIDPOpts []oidctest.FakeIDPOpt
|
||||
CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig)
|
||||
GitConfigOpt func(cfg *gitauth.Config)
|
||||
// If DB is passed in, the link will be inserted into the DB.
|
||||
DB database.Store
|
||||
}
|
||||
|
||||
// setupTest will configure a fake IDP and a gitauth.Config for testing.
|
||||
// The Fake's userinfo endpoint is used for validating tokens.
|
||||
// No http servers are started so use the fake IDP's HTTPClient to make requests.
|
||||
// The returned token is a fully valid token for the IDP. Feel free to manipulate it
|
||||
// to test different scenarios.
|
||||
func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *gitauth.Config, database.GitAuthLink) {
|
||||
t.Helper()
|
||||
|
||||
const providerID = "test-idp"
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)...,
|
||||
)
|
||||
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...),
|
||||
ID: providerID,
|
||||
ValidateURL: fake.WellknownConfig().UserInfoURL,
|
||||
}
|
||||
settings.GitConfigOpt(config)
|
||||
|
||||
oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
|
||||
"email": "test@coder.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Now()
|
||||
link := database.GitAuthLink{
|
||||
ProviderID: providerID,
|
||||
UserID: uuid.New(),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
OAuthAccessToken: oauthToken.AccessToken,
|
||||
OAuthRefreshToken: oauthToken.RefreshToken,
|
||||
// The caller can manually expire this if they want.
|
||||
OAuthExpiry: now.Add(time.Hour),
|
||||
}
|
||||
|
||||
if settings.DB != nil {
|
||||
// Feel free to insert additional things like the user, etc if required.
|
||||
link, err = settings.DB.InsertGitAuthLink(context.Background(), database.InsertGitAuthLinkParams{
|
||||
ProviderID: link.ProviderID,
|
||||
UserID: link.UserID,
|
||||
CreatedAt: link.CreatedAt,
|
||||
UpdatedAt: link.UpdatedAt,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: link.OAuthExpiry,
|
||||
})
|
||||
require.NoError(t, err, "failed to insert link into DB")
|
||||
}
|
||||
|
||||
return fake, config, link
|
||||
}
|
||||
|
Reference in New Issue
Block a user