fix: make 'NoRefresh' honor unlimited tokens in gitauth (#9472)

* chore: fix NoRefresh to honor unlimited tokens
* improve testing coverage of gitauth
* refactor rest of gitauth tests
This commit is contained in:
Steven Masley
2023-09-05 09:08:04 -05:00
committed by GitHub
parent da0ef92f77
commit 58f7071569
5 changed files with 354 additions and 112 deletions

View File

@ -7,6 +7,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
@ -41,7 +42,7 @@ import (
type FakeIDP struct {
issuer string
key *rsa.PrivateKey
provider providerJSON
provider ProviderJSON
handler http.Handler
cfg *oauth2.Config
@ -66,7 +67,7 @@ type FakeIDP struct {
// IDP -> Application. Almost all IDPs have the concept of
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) jwt.MapClaims
hookUserInfo func(email string) (jwt.MapClaims, error)
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
@ -75,6 +76,26 @@ type FakeIDP struct {
serve bool
}
func StatusError(code int, err error) error {
return statusHookError{
Err: err,
HTTPStatusCode: code,
}
}
// statusHookError allows a hook to change the returned http status code.
type statusHookError struct {
Err error
HTTPStatusCode int
}
func (s statusHookError) Error() string {
if s.Err == nil {
return ""
}
return s.Err.Error()
}
type FakeIDPOpt func(idp *FakeIDP)
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
@ -83,9 +104,9 @@ func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeID
}
}
// WithRefreshHook is called when a refresh token is used. The email is
// WithRefresh is called when a refresh token is used. The email is
// the email of the user that is being refreshed assuming the claims are correct.
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
func WithRefresh(hook func(email string) error) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookOnRefresh = hook
}
@ -108,13 +129,13 @@ func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
// every user on the /userinfo endpoint.
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = func(_ string) jwt.MapClaims {
return info
f.hookUserInfo = func(_ string) (jwt.MapClaims, error) {
return info, nil
}
}
}
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
func WithDynamicUserInfo(userInfoFunc func(email string) (jwt.MapClaims, error)) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookUserInfo = userInfoFunc
}
@ -160,7 +181,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
hookOnRefresh: func(_ string) error { return nil },
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
hookValidRedirectURL: func(redirectURL string) error { return nil },
}
@ -181,6 +202,10 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
return idp
}
func (f *FakeIDP) WellknownConfig() ProviderJSON {
return f.provider
}
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
t.Helper()
@ -188,9 +213,9 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
require.NoError(t, err, "invalid issuer URL")
f.issuer = issuer
// providerJSON is the JSON representation of the OpenID Connect provider
// ProviderJSON is the JSON representation of the OpenID Connect provider
// These are all the urls that the IDP will respond to.
f.provider = providerJSON{
f.provider = ProviderJSON{
Issuer: issuer,
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
@ -220,6 +245,15 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
return srv
}
// GenerateAuthenticatedToken skips all oauth2 flows, and just generates a
// valid token for some given claims.
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
state := uuid.NewString()
f.stateToIDTokenClaims.Store(state, claims)
code := f.newCode(state)
return f.cfg.Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
}
// Login does the full OIDC flow starting at the "LoginButton".
// The client argument is just to get the URL of the Coder instance.
//
@ -333,7 +367,8 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
return resp, nil
}
type providerJSON struct {
// ProviderJSON is the .well-known/configuration JSON
type ProviderJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
@ -475,7 +510,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
err := f.hookValidRedirectURL(redirectURI)
if err != nil {
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
@ -501,7 +536,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
slog.F("values", values.Encode()),
)
if err != nil {
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
getEmail := func(claims jwt.MapClaims) string {
@ -562,7 +597,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
claims = idTokenClaims
err := f.hookOnRefresh(getEmail(claims))
if err != nil {
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
@ -610,7 +645,12 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
return
}
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
claims, err := f.hookUserInfo(email)
if err != nil {
http.Error(rw, fmt.Sprintf("user info hook returned error: %s", err.Error()), httpErrorCode(http.StatusBadRequest, err))
return
}
_ = json.NewEncoder(rw).Encode(claims)
}))
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -768,6 +808,15 @@ func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *co
return cfg
}
func httpErrorCode(defaultCode int, err error) int {
var stautsErr statusHookError
status := defaultCode
if errors.As(err, &stautsErr) {
status = stautsErr.HTTPStatusCode
}
return status
}
type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}