mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
test: add full OIDC fake IDP (#9317)
* test: implement fake OIDC provider with full functionality * Refactor existing tests
This commit is contained in:
@ -31,15 +31,13 @@ import (
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
// These are optional
|
||||
refreshToken string
|
||||
oidcTokenExpires func() time.Time
|
||||
tokenSource func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.refreshToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.oidcTokenExpires = expFunc
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.tokenSource = src
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
if issuer == "" {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
cfg := &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
type tokenSource struct {
|
||||
src func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.src()
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
if cfg.tokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
return tokenSource{
|
||||
src: cfg.tokenSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
|
||||
var exp time.Time
|
||||
if cfg.oidcTokenExpires != nil {
|
||||
exp = cfg.oidcTokenExpires()
|
||||
}
|
||||
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: cfg.refreshToken,
|
||||
Expiry: exp,
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = cfg.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
// By default, the provider can be empty.
|
||||
// This means it won't support any endpoints!
|
||||
provider := &oidc.Provider{}
|
||||
if userInfoClaims != nil {
|
||||
resp, err := json.Marshal(userInfoClaims)
|
||||
require.NoError(t, err)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(resp)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
cfg := &oidc.ProviderConfig{
|
||||
UserInfoURL: srv.URL,
|
||||
}
|
||||
provider = cfg.NewProvider(context.Background())
|
||||
}
|
||||
newCFG := &coderd.OIDCConfig{
|
||||
OAuth2Config: cfg,
|
||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
Provider: provider,
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(newCFG)
|
||||
}
|
||||
return newCFG
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
// instance authentication for Azure.
|
||||
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
||||
@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
|
||||
return cerr
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||
var cfg codersdk.DeploymentValues
|
||||
opts := cfg.Options()
|
||||
|
103
coderd/coderdtest/oidctest/helper.go
Normal file
103
coderd/coderdtest/oidctest/helper.go
Normal file
@ -0,0 +1,103 @@
|
||||
package oidctest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
|
||||
// It is mainly because refreshing oauth tokens is a bit tricky and requires
|
||||
// some database manipulation.
|
||||
type LoginHelper struct {
|
||||
fake *FakeIDP
|
||||
client *codersdk.Client
|
||||
}
|
||||
|
||||
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
|
||||
if client == nil {
|
||||
panic("client must not be nil")
|
||||
}
|
||||
if fake == nil {
|
||||
panic("fake must not be nil")
|
||||
}
|
||||
return &LoginHelper{
|
||||
fake: fake,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Login just helps by making an unauthenticated client and logging in with
|
||||
// the given claims. All Logins should be unauthenticated, so this is a
|
||||
// convenience method.
|
||||
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
unauthenticatedClient := codersdk.New(h.client.URL)
|
||||
|
||||
return h.fake.Login(t, unauthenticatedClient, idTokenClaims)
|
||||
}
|
||||
|
||||
// ExpireOauthToken expires the oauth token for the given user.
|
||||
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
|
||||
t.Helper()
|
||||
|
||||
//nolint:gocritic // Testing
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
|
||||
id, _, err := httpmw.SplitAPIToken(user.SessionToken())
|
||||
require.NoError(t, err)
|
||||
|
||||
// We need to get the OIDC link and update it in the database to force
|
||||
// it to be expired.
|
||||
key, err := db.GetAPIKeyByID(ctx, id)
|
||||
require.NoError(t, err, "get api key")
|
||||
|
||||
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
})
|
||||
require.NoError(t, err, "get user link")
|
||||
|
||||
// Expire the oauth link for the given user.
|
||||
updated, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
OAuthRefreshToken: link.OAuthRefreshToken,
|
||||
OAuthExpiry: time.Now().Add(time.Hour * -1),
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
})
|
||||
require.NoError(t, err, "expire user link")
|
||||
|
||||
return updated
|
||||
}
|
||||
|
||||
// ForceRefresh forces the client to refresh its oauth token. It does this by
|
||||
// expiring the oauth token, then doing an authenticated call. This will force
|
||||
// the API Key middleware to refresh the oauth token.
|
||||
//
|
||||
// A unit test assertion makes sure the refresh token is used.
|
||||
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
|
||||
t.Helper()
|
||||
|
||||
link := h.ExpireOauthToken(t, db, user)
|
||||
// Updates the claims that the IDP will return. By default, it always
|
||||
// uses the original claims for the original oauth token.
|
||||
h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not. Did you forget to call the returned function from this call?")
|
||||
})
|
||||
|
||||
// Do any authenticated call to force the refresh
|
||||
_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
|
||||
require.NoError(t, err, "user must be able to be fetched")
|
||||
}
|
793
coderd/coderdtest/oidctest/idp.go
Normal file
793
coderd/coderdtest/oidctest/idp.go
Normal file
@ -0,0 +1,793 @@
|
||||
package oidctest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/syncmap"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// FakeIDP is a functional OIDC provider.
|
||||
// It only supports 1 OIDC client.
|
||||
type FakeIDP struct {
|
||||
issuer string
|
||||
key *rsa.PrivateKey
|
||||
provider providerJSON
|
||||
handler http.Handler
|
||||
cfg *oauth2.Config
|
||||
|
||||
// clientID to be used by coderd
|
||||
clientID string
|
||||
clientSecret string
|
||||
logger slog.Logger
|
||||
|
||||
// These maps are used to control the state of the IDP.
|
||||
// That is the various access tokens, refresh tokens, states, etc.
|
||||
codeToStateMap *syncmap.Map[string, string]
|
||||
// Token -> Email
|
||||
accessTokens *syncmap.Map[string, string]
|
||||
// Refresh Token -> Email
|
||||
refreshTokensUsed *syncmap.Map[string, bool]
|
||||
refreshTokens *syncmap.Map[string, string]
|
||||
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
|
||||
// hooks
|
||||
// hookValidRedirectURL can be used to reject a redirect url from the
|
||||
// IDP -> Application. Almost all IDPs have the concept of
|
||||
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||
hookValidRedirectURL func(redirectURL string) error
|
||||
hookUserInfo func(email string) jwt.MapClaims
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
// to test something like PKI auth vs a client_secret.
|
||||
hookAuthenticateClient func(t testing.TB, req *http.Request) (url.Values, error)
|
||||
serve bool
|
||||
}
|
||||
|
||||
type FakeIDPOpt func(idp *FakeIDP)
|
||||
|
||||
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookValidRedirectURL = hook
|
||||
}
|
||||
}
|
||||
|
||||
// WithRefreshHook is called when a refresh token is used. The email is
|
||||
// the email of the user that is being refreshed assuming the claims are correct.
|
||||
func WithRefreshHook(hook func(email string) error) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookOnRefresh = hook
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookAuthenticateClient = hook
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogging is optional, but will log some HTTP calls made to the IDP.
|
||||
func WithLogging(t testing.TB, options *slogtest.Options) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.logger = slogtest.Make(t, options)
|
||||
}
|
||||
}
|
||||
|
||||
// WithStaticUserInfo is optional, but will return the same user info for
|
||||
// every user on the /userinfo endpoint.
|
||||
func WithStaticUserInfo(info jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = func(_ string) jwt.MapClaims {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithDynamicUserInfo(userInfoFunc func(email string) jwt.MapClaims) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookUserInfo = userInfoFunc
|
||||
}
|
||||
}
|
||||
|
||||
// WithServing makes the IDP run an actual http server.
|
||||
func WithServing() func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.serve = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithIssuer(issuer string) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.issuer = issuer
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// nolint:gosec // It thinks this is a secret lol
|
||||
tokenPath = "/oauth2/token"
|
||||
authorizePath = "/oauth2/authorize"
|
||||
keysPath = "/oauth2/keys"
|
||||
userInfoPath = "/oauth2/userinfo"
|
||||
)
|
||||
|
||||
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
idp := &FakeIDP{
|
||||
key: pkey,
|
||||
clientID: uuid.NewString(),
|
||||
clientSecret: uuid.NewString(),
|
||||
logger: slog.Make(),
|
||||
codeToStateMap: syncmap.New[string, string](),
|
||||
accessTokens: syncmap.New[string, string](),
|
||||
refreshTokens: syncmap.New[string, string](),
|
||||
refreshTokensUsed: syncmap.New[string, bool](),
|
||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) jwt.MapClaims { return jwt.MapClaims{} },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(idp)
|
||||
}
|
||||
|
||||
if idp.issuer == "" {
|
||||
idp.issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
idp.handler = idp.httpHandler(t)
|
||||
idp.updateIssuerURL(t, idp.issuer)
|
||||
if idp.serve {
|
||||
idp.realServer(t)
|
||||
}
|
||||
|
||||
return idp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
t.Helper()
|
||||
|
||||
u, err := url.Parse(issuer)
|
||||
require.NoError(t, err, "invalid issuer URL")
|
||||
|
||||
f.issuer = issuer
|
||||
// providerJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = providerJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||
Algorithms: []string{
|
||||
"RS256",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// realServer turns the FakeIDP into a real http server.
|
||||
func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
srv := httptest.NewUnstartedServer(f.handler)
|
||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
return ctx
|
||||
}
|
||||
srv.Start()
|
||||
t.Cleanup(srv.CloseClientConnections)
|
||||
t.Cleanup(srv.Close)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
f.updateIssuerURL(t, srv.URL)
|
||||
return srv
|
||||
}
|
||||
|
||||
// Login does the full OIDC flow starting at the "LoginButton".
|
||||
// The client argument is just to get the URL of the Coder instance.
|
||||
//
|
||||
// The client passed in is just to get the url of the Coder instance.
|
||||
// The actual client that is used is 100% unauthenticated and fresh.
|
||||
func (f *FakeIDP) Login(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
|
||||
client, resp := f.AttemptLogin(t, client, idTokenClaims, opts...)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode, "client failed to login")
|
||||
return client, resp
|
||||
}
|
||||
|
||||
func (f *FakeIDP) AttemptLogin(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
var err error
|
||||
|
||||
cli := f.HTTPClient(client.HTTPClient)
|
||||
shallowCpyCli := *cli
|
||||
|
||||
if shallowCpyCli.Jar == nil {
|
||||
shallowCpyCli.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err, "failed to create cookie jar")
|
||||
}
|
||||
|
||||
unauthenticated := codersdk.New(client.URL)
|
||||
unauthenticated.HTTPClient = &shallowCpyCli
|
||||
|
||||
return f.LoginWithClient(t, unauthenticated, idTokenClaims, opts...)
|
||||
}
|
||||
|
||||
// LoginWithClient reuses the context of the passed in client. This means the same
|
||||
// cookies will be used. This should be an unauthenticated client in most cases.
|
||||
//
|
||||
// This is a niche case, but it is needed for testing ConvertLoginType.
|
||||
func (f *FakeIDP) LoginWithClient(t testing.TB, client *codersdk.Client, idTokenClaims jwt.MapClaims, opts ...func(r *http.Request)) (*codersdk.Client, *http.Response) {
|
||||
t.Helper()
|
||||
|
||||
coderOauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback")
|
||||
require.NoError(t, err)
|
||||
f.SetRedirect(t, coderOauthURL.String())
|
||||
|
||||
cli := f.HTTPClient(client.HTTPClient)
|
||||
cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
// Store the idTokenClaims to the specific state request. This ties
|
||||
// the claims 1:1 with a given authentication flow.
|
||||
state := req.URL.Query().Get("state")
|
||||
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", coderOauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
if cli.Jar == nil {
|
||||
cli.Jar, err = cookiejar.New(nil)
|
||||
require.NoError(t, err, "failed to create cookie jar")
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(req)
|
||||
}
|
||||
|
||||
res, err := cli.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// If the coder session token exists, return the new authed client!
|
||||
var user *codersdk.Client
|
||||
cookies := cli.Jar.Cookies(client.URL)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == codersdk.SessionTokenCookie {
|
||||
user = codersdk.New(client.URL)
|
||||
user.SetSessionToken(cookie.Value)
|
||||
}
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
if res.Body != nil {
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return user, res
|
||||
}
|
||||
|
||||
// OIDCCallback will emulate the IDP redirecting back to the Coder callback.
|
||||
// This is helpful if no Coderd exists because the IDP needs to redirect to
|
||||
// something.
|
||||
// Essentially this is used to fake the Coderd side of the exchange.
|
||||
// The flow starts at the user hitting the OIDC login page.
|
||||
func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.MapClaims) (*http.Response, error) {
|
||||
t.Helper()
|
||||
if f.serve {
|
||||
panic("cannot use OIDCCallback with WithServing. This is only for the in memory usage")
|
||||
}
|
||||
|
||||
f.stateToIDTokenClaims.Store(state, idTokenClaims)
|
||||
|
||||
cli := f.HTTPClient(nil)
|
||||
u := f.cfg.AuthCodeURL(state)
|
||||
req, err := http.NewRequest("GET", u, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := cli.Do(req.WithContext(context.Background()))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
})
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
type providerJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||
}
|
||||
|
||||
// newCode enforces the code exchanged is actually a valid code
|
||||
// created by the IDP.
|
||||
func (f *FakeIDP) newCode(state string) string {
|
||||
code := uuid.NewString()
|
||||
f.codeToStateMap.Store(code, state)
|
||||
return code
|
||||
}
|
||||
|
||||
// newToken enforces the access token exchanged is actually a valid access token
|
||||
// created by the IDP.
|
||||
func (f *FakeIDP) newToken(email string) string {
|
||||
accessToken := uuid.NewString()
|
||||
f.accessTokens.Store(accessToken, email)
|
||||
return accessToken
|
||||
}
|
||||
|
||||
func (f *FakeIDP) newRefreshTokens(email string) string {
|
||||
refreshToken := uuid.NewString()
|
||||
f.refreshTokens.Store(refreshToken, email)
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
// authenticateBearerTokenRequest enforces the access token is valid.
|
||||
func (f *FakeIDP) authenticateBearerTokenRequest(t testing.TB, req *http.Request) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
auth := req.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
_, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
return "", xerrors.New("invalid access token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// authenticateOIDCClientRequest enforces the client_id and client_secret are valid.
|
||||
func (f *FakeIDP) authenticateOIDCClientRequest(t testing.TB, req *http.Request) (url.Values, error) {
|
||||
t.Helper()
|
||||
|
||||
if f.hookAuthenticateClient != nil {
|
||||
return f.hookAuthenticateClient(t, req)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(req.Body)
|
||||
if !assert.NoError(t, err, "read token request body") {
|
||||
return nil, xerrors.Errorf("authenticate request, read body: %w", err)
|
||||
}
|
||||
values, err := url.ParseQuery(string(data))
|
||||
if !assert.NoError(t, err, "parse token request values") {
|
||||
return nil, xerrors.New("invalid token request")
|
||||
}
|
||||
|
||||
if !assert.Equal(t, f.clientID, values.Get("client_id"), "client_id mismatch") {
|
||||
return nil, xerrors.New("client_id mismatch")
|
||||
}
|
||||
|
||||
if !assert.Equal(t, f.clientSecret, values.Get("client_secret"), "client_secret mismatch") {
|
||||
return nil, xerrors.New("client_secret mismatch")
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// encodeClaims is a helper func to convert claims to a valid JWT.
|
||||
func (f *FakeIDP) encodeClaims(t testing.TB, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["aud"]; !ok {
|
||||
claims["aud"] = f.clientID
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = f.issuer
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(f.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return signed
|
||||
}
|
||||
|
||||
// httpHandler is the IDP http server.
|
||||
func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
mux := chi.NewMux()
|
||||
// This endpoint is required to initialize the OIDC provider.
|
||||
// It is used to get the OIDC configuration.
|
||||
mux.Get("/.well-known/openid-configuration", func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http OIDC config", slog.F("url", r.URL.String()))
|
||||
|
||||
_ = json.NewEncoder(rw).Encode(f.provider)
|
||||
})
|
||||
|
||||
// Authorize is called when the user is redirected to the IDP to login.
|
||||
// This is the browser hitting the IDP and the user logging into Google or
|
||||
// w/e and clicking "Allow". They will be redirected back to the redirect
|
||||
// when this is done.
|
||||
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call authorize", slog.F("url", r.URL.String()))
|
||||
|
||||
clientID := r.URL.Query().Get("client_id")
|
||||
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
|
||||
http.Error(rw, "invalid client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := r.URL.Query().Get("redirect_uri")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
scope := r.URL.Query().Get("scope")
|
||||
assert.NotEmpty(t, scope, "scope is empty")
|
||||
|
||||
responseType := r.URL.Query().Get("response_type")
|
||||
switch responseType {
|
||||
case "code":
|
||||
case "token":
|
||||
t.Errorf("response_type %q not supported", responseType)
|
||||
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||
return
|
||||
default:
|
||||
t.Errorf("unexpected response_type %q", responseType)
|
||||
http.Error(rw, "invalid response_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err := f.hookValidRedirectURL(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("not authorized redirect_uri by custom hook %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ru, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
t.Errorf("invalid redirect_uri %q: %s", redirectURI, err.Error())
|
||||
http.Error(rw, fmt.Sprintf("invalid redirect_uri: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
q := ru.Query()
|
||||
q.Set("state", state)
|
||||
q.Set("code", f.newCode(state))
|
||||
ru.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
|
||||
}))
|
||||
|
||||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http idp call token",
|
||||
slog.Error(err),
|
||||
slog.F("values", values.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid token request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
getEmail := func(claims jwt.MapClaims) string {
|
||||
email, ok := claims["email"]
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
emailStr, ok := email.(string)
|
||||
if !ok {
|
||||
return "wrong-type"
|
||||
}
|
||||
return emailStr
|
||||
}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
switch values.Get("grant_type") {
|
||||
case "authorization_code":
|
||||
code := values.Get("code")
|
||||
if !assert.NotEmpty(t, code, "code is empty") {
|
||||
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
stateStr, ok := f.codeToStateMap.Load(code)
|
||||
if !assert.True(t, ok, "invalid code") {
|
||||
http.Error(rw, "invalid code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Always invalidate the code after it is used.
|
||||
f.codeToStateMap.Delete(code)
|
||||
|
||||
idTokenClaims, ok := f.stateToIDTokenClaims.Load(stateStr)
|
||||
if !ok {
|
||||
t.Errorf("missing id token claims")
|
||||
http.Error(rw, "missing id token claims", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
claims = idTokenClaims
|
||||
case "refresh_token":
|
||||
refreshToken := values.Get("refresh_token")
|
||||
if !assert.NotEmpty(t, refreshToken, "refresh_token is empty") {
|
||||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := f.refreshTokens.Load(refreshToken)
|
||||
if !assert.True(t, ok, "invalid refresh_token") {
|
||||
http.Error(rw, "invalid refresh_token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idTokenClaims, ok := f.refreshIDTokenClaims.Load(refreshToken)
|
||||
if !ok {
|
||||
t.Errorf("missing id token claims in refresh")
|
||||
http.Error(rw, "missing id token claims in refresh", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
claims = idTokenClaims
|
||||
err := f.hookOnRefresh(getEmail(claims))
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("refresh hook blocked refresh: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
f.refreshTokensUsed.Store(refreshToken, true)
|
||||
// Always invalidate the refresh token after it is used.
|
||||
f.refreshTokens.Delete(refreshToken)
|
||||
default:
|
||||
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
||||
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
exp := time.Now().Add(time.Minute * 5)
|
||||
claims["exp"] = exp.UnixMilli()
|
||||
email := getEmail(claims)
|
||||
refreshToken := f.newRefreshTokens(email)
|
||||
token := map[string]interface{}{
|
||||
"access_token": f.newToken(email),
|
||||
"refresh_token": refreshToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": int64((time.Minute * 5).Seconds()),
|
||||
"id_token": f.encodeClaims(t, claims),
|
||||
}
|
||||
// Store the claims for the next refresh
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(rw).Encode(token)
|
||||
}))
|
||||
|
||||
mux.Handle(userInfoPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
token, err := f.authenticateBearerTokenRequest(t, r)
|
||||
f.logger.Info(r.Context(), "http call idp user info",
|
||||
slog.Error(err),
|
||||
slog.F("url", r.URL.String()),
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(rw, fmt.Sprintf("invalid user info request: %s", err.Error()), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
email, ok := f.accessTokens.Load(token)
|
||||
if !ok {
|
||||
t.Errorf("access token user for user_info has no email to indicate which user")
|
||||
http.Error(rw, "invalid access token, missing user info", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(f.hookUserInfo(email))
|
||||
}))
|
||||
|
||||
mux.Handle(keysPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call idp /keys")
|
||||
set := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: f.key.Public(),
|
||||
KeyID: "test-key",
|
||||
Algorithm: "RSA",
|
||||
},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(rw).Encode(set)
|
||||
}))
|
||||
|
||||
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
||||
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
||||
})
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// HTTPClient does nothing if IsServing is used.
|
||||
//
|
||||
// If IsServing is not used, then it will return a client that will make requests
|
||||
// to the IDP all in memory. If a request is not to the IDP, then the passed in
|
||||
// client will be used. If no client is passed in, then any regular network
|
||||
// requests will fail.
|
||||
func (f *FakeIDP) HTTPClient(rest *http.Client) *http.Client {
|
||||
if f.serve {
|
||||
if rest == nil || rest.Transport == nil {
|
||||
return &http.Client{}
|
||||
}
|
||||
return rest
|
||||
}
|
||||
|
||||
var jar http.CookieJar
|
||||
if rest != nil {
|
||||
jar = rest.Jar
|
||||
}
|
||||
return &http.Client{
|
||||
Jar: jar,
|
||||
Transport: fakeRoundTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
u, _ := url.Parse(f.issuer)
|
||||
if req.URL.Host != u.Host {
|
||||
if f.fakeCoderd != nil {
|
||||
return f.fakeCoderd(req)
|
||||
}
|
||||
if rest == nil || rest.Transport == nil {
|
||||
return nil, fmt.Errorf("unexpected network request to %q", req.URL.Host)
|
||||
}
|
||||
return rest.Transport.RoundTrip(req)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
f.handler.ServeHTTP(resp, req)
|
||||
return resp.Result(), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshUsed returns if the refresh token has been used. All refresh tokens
|
||||
// can only be used once, then they are deleted.
|
||||
func (f *FakeIDP) RefreshUsed(refreshToken string) bool {
|
||||
used, _ := f.refreshTokensUsed.Load(refreshToken)
|
||||
return used
|
||||
}
|
||||
|
||||
// UpdateRefreshClaims allows the caller to change what claims are returned
|
||||
// for a given refresh token. By default, all refreshes use the same claims as
|
||||
// the original IDToken issuance.
|
||||
func (f *FakeIDP) UpdateRefreshClaims(refreshToken string, claims jwt.MapClaims) {
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
}
|
||||
|
||||
// SetRedirect is required for the IDP to know where to redirect and call
|
||||
// Coderd.
|
||||
func (f *FakeIDP) SetRedirect(t testing.TB, u string) {
|
||||
t.Helper()
|
||||
|
||||
f.cfg.RedirectURL = u
|
||||
}
|
||||
|
||||
// SetCoderdCallback is optional and only works if not using the IsServing.
|
||||
// It will setup a fake "Coderd" for the IDP to call when the IDP redirects
|
||||
// back after authenticating.
|
||||
func (f *FakeIDP) SetCoderdCallback(callback func(req *http.Request) (*http.Response, error)) {
|
||||
if f.serve {
|
||||
panic("cannot set callback handler when using 'WithServing'. Must implement an actual 'Coderd'")
|
||||
}
|
||||
f.fakeCoderd = callback
|
||||
}
|
||||
|
||||
func (f *FakeIDP) SetCoderdCallbackHandler(handler http.HandlerFunc) {
|
||||
f.SetCoderdCallback(func(req *http.Request) (*http.Response, error) {
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
return resp.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
// OIDCConfig returns the OIDC config to use for Coderd.
|
||||
func (f *FakeIDP) OIDCConfig(t testing.TB, scopes []string, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
t.Helper()
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "email", "profile"}
|
||||
}
|
||||
|
||||
oauthCfg := &oauth2.Config{
|
||||
ClientID: f.clientID,
|
||||
ClientSecret: f.clientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: f.provider.AuthURL,
|
||||
TokenURL: f.provider.TokenURL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
// If the user is using a real network request, they will need to do
|
||||
// 'fake.SetRedirect()'
|
||||
RedirectURL: "https://redirect.com",
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), f.HTTPClient(nil))
|
||||
p, err := oidc.NewProvider(ctx, f.provider.Issuer)
|
||||
require.NoError(t, err, "failed to create OIDC provider")
|
||||
cfg := &coderd.OIDCConfig{
|
||||
OAuth2Config: oauthCfg,
|
||||
Provider: p,
|
||||
Verifier: oidc.NewVerifier(f.provider.Issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{f.key.Public()},
|
||||
}, &oidc.Config{
|
||||
ClientID: oauthCfg.ClientID,
|
||||
SupportedSigningAlgs: []string{
|
||||
"RS256",
|
||||
},
|
||||
// Todo: add support for Now()
|
||||
}),
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt == nil {
|
||||
continue
|
||||
}
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
f.cfg = oauthCfg
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
roundTrip func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f.roundTrip(req)
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
72
coderd/coderdtest/oidctest/idp_test.go
Normal file
72
coderd/coderdtest/oidctest/idp_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package oidctest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// TestFakeIDPBasicFlow tests the basic flow of the fake IDP.
|
||||
// It is done all in memory with no actual network requests.
|
||||
// nolint:bodyclose
|
||||
func TestFakeIDPBasicFlow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithLogging(t, nil),
|
||||
)
|
||||
|
||||
var handler http.Handler
|
||||
srv := httptest.NewServer(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler.ServeHTTP(w, r)
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
cfg := fake.OIDCConfig(t, nil)
|
||||
cli := fake.HTTPClient(nil)
|
||||
ctx := oidc.ClientContext(context.Background(), cli)
|
||||
|
||||
const expectedState = "random-state"
|
||||
var token *oauth2.Token
|
||||
// This is the Coder callback using an actual network request.
|
||||
fake.SetCoderdCallbackHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Emulate OIDC flow
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
assert.Equal(t, expectedState, state, "state mismatch")
|
||||
|
||||
oauthToken, err := cfg.Exchange(ctx, code)
|
||||
if assert.NoError(t, err, "failed to exchange code") {
|
||||
assert.NotEmpty(t, oauthToken.AccessToken, "access token is empty")
|
||||
assert.NotEmpty(t, oauthToken.RefreshToken, "refresh token is empty")
|
||||
}
|
||||
token = oauthToken
|
||||
})
|
||||
|
||||
resp, err := fake.OIDCCallback(t, expectedState, jwt.MapClaims{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Test the user info
|
||||
_, err = cfg.Provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now test it can refresh
|
||||
refreshed, err := cfg.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: time.Now().Add(time.Minute * -1),
|
||||
}).Token()
|
||||
require.NoError(t, err, "failed to refresh token")
|
||||
require.NotEmpty(t, refreshed.AccessToken, "access token is empty on refresh")
|
||||
}
|
Reference in New Issue
Block a user