mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
fix: make 'NoRefresh' honor unlimited tokens in gitauth (#9472)
* chore: fix NoRefresh to honor unlimited tokens * improve testing coverage of gitauth * refactor rest of gitauth tests
This commit is contained in:
@ -7,6 +7,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -41,7 +42,7 @@ import (
|
||||
type FakeIDP struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
provider providerJSON
|
||||
provider ProviderJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
|
||||
@ -66,7 +67,7 @@ type FakeIDP struct {
|
||||
// IDP -> Application. Almost all IDPs have the concept of
|
||||
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||
hookValidRedirectURL func(redirectURL string) error
|
||||
hookUserInfo func(email string) jwt.MapClaims
|
||||
hookUserInfo func(email string) (jwt.MapClaims, error)
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
@ -75,6 +76,26 @@ type FakeIDP struct {
|
||||
serve bool
|
||||
}
|
||||
|
||||
func StatusError(code int, err error) error {
|
||||
return statusHookError{
|
||||
Err: err,
|
||||
HTTPStatusCode: code,
|
||||
}
|
||||
}
|
||||
|
||||
// statusHookError allows a hook to change the returned http status code.
|
||||
type statusHookError struct {
|
||||
Err error
|
||||
HTTPStatusCode int
|
||||
}
|
||||
|
||||
func (s statusHookError) Error() string {
|
||||
if s.Err == nil {
|
||||
return ""
|
||||
}
|
||||
return s.Err.Error()
|
||||
}
|
||||
|
||||
type FakeIDPOpt func(idp *FakeIDP)
|
||||
|
||||
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
|
||||
@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
|
||||
}
|
||||
}
|
||||
|
||||
// WithRefreshHook is called when a refresh token is used. The email is
|
||||
// WithRefresh is called when a refresh token is used. The email is
|
||||
// the email of the user that is being refreshed assuming the claims are correct.
|
||||
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
|
||||
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookOnRefresh = hook
|
||||
}
|
||||
@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||
// every user on the /userinfo endpoint.
|
||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = func(_ string) jwt.MapClaims {
|
||||
return info
|
||||
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
|
||||
return info, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = userInfoFunc
|
||||
}
|
||||
@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
|
||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
}
|
||||
|
||||
@ -181,6 +202,10 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
return idp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) WellknownConfig() ProviderJSON {
|
||||
return f.provider
|
||||
}
|
||||
|
||||
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
t.Helper()
|
||||
|
||||
@ -188,9 +213,9 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
require.NoError(t, err, "invalid issuer URL")
|
||||
|
||||
f.issuer = issuer
|
||||
// providerJSON is the JSON representation of the OpenID Connect provider
|
||||
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = providerJSON{
|
||||
f.provider = ProviderJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
|
||||
// valid token for some given claims.
|
||||
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
|
||||
state := uuid.NewString()
|
||||
f.stateToIDTokenClaims.Store(state, claims)
|
||||
code := f.newCode(state)
|
||||
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
|
||||
}
|
||||
|
||||
// Login does the full OIDC flow starting at the "LoginButton".
|
||||
// The client argument is just to get the URL of the Coder instance.
|
||||
//
|
||||
@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type providerJSON struct {
|
||||
// ProviderJSON is the .well-known/configuration JSON
|
||||
type ProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
err := f.hookValidRedirectURL(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
|
||||
@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
slog.F("values", values.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
getEmail := func(claims jwt.MapClaims) string {
|
||||
@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
claims = idTokenClaims
|
||||
err := f.hookOnRefresh(getEmail(claims))
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
|
||||
@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
|
||||
claims, err := f.hookUserInfo(email)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(claims)
|
||||
}))
|
||||
|
||||
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
|
||||
return cfg
|
||||
}
|
||||
|
||||
func httpErrorCode(defaultCode int, err error) int {
|
||||
var stautsErr statusHookError
|
||||
status := defaultCode
|
||||
if errors.As(err, &stautsErr) {
|
||||
status = stautsErr.HTTPStatusCode
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
roundTrip func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
Reference in New Issue
Block a user