mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
test: add full OIDC fake IDP (#9317)
* test: implement fake OIDC provider with full functionality * Refactor existing tests
This commit is contained in:
@ -31,15 +31,13 @@ import (
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/fullsailor/pkcs7"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
@ -1020,152 +1018,6 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
// These are optional
|
||||
refreshToken string
|
||||
oidcTokenExpires func() time.Time
|
||||
tokenSource func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.refreshToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.oidcTokenExpires = expFunc
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.tokenSource = src
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
if issuer == "" {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
cfg := &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
type tokenSource struct {
|
||||
src func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.src()
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
if cfg.tokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
return tokenSource{
|
||||
src: cfg.tokenSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
|
||||
var exp time.Time
|
||||
if cfg.oidcTokenExpires != nil {
|
||||
exp = cfg.oidcTokenExpires()
|
||||
}
|
||||
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: cfg.refreshToken,
|
||||
Expiry: exp,
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = cfg.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
// By default, the provider can be empty.
|
||||
// This means it won't support any endpoints!
|
||||
provider := &oidc.Provider{}
|
||||
if userInfoClaims != nil {
|
||||
resp, err := json.Marshal(userInfoClaims)
|
||||
require.NoError(t, err)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(resp)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
cfg := &oidc.ProviderConfig{
|
||||
UserInfoURL: srv.URL,
|
||||
}
|
||||
provider = cfg.NewProvider(context.Background())
|
||||
}
|
||||
newCFG := &coderd.OIDCConfig{
|
||||
OAuth2Config: cfg,
|
||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
Provider: provider,
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(newCFG)
|
||||
}
|
||||
return newCFG
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
// instance authentication for Azure.
|
||||
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) {
|
||||
@ -1254,22 +1106,6 @@ func SDKError(t *testing.T, err error) *codersdk.Error {
|
||||
return cerr
|
||||
}
|
||||
|
||||
const testRSAPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXQIBAAKBgQDLets8+7M+iAQAqN/5BVyCIjhTQ4cmXulL+gm3v0oGMWzLupUS
|
||||
v8KPA+Tp7dgC/DZPfMLaNH1obBBhJ9DhS6RdS3AS3kzeFrdu8zFHLWF53DUBhS92
|
||||
5dCAEuJpDnNizdEhxTfoHrhuCmz8l2nt1pe5eUK2XWgd08Uc93h5ij098wIDAQAB
|
||||
AoGAHLaZeWGLSaen6O/rqxg2laZ+jEFbMO7zvOTruiIkL/uJfrY1kw+8RLIn+1q0
|
||||
wLcWcuEIHgKKL9IP/aXAtAoYh1FBvRPLkovF1NZB0Je/+CSGka6wvc3TGdvppZJe
|
||||
rKNcUvuOYLxkmLy4g9zuY5qrxFyhtIn2qZzXEtLaVOHzPQECQQDvN0mSajpU7dTB
|
||||
w4jwx7IRXGSSx65c+AsHSc1Rj++9qtPC6WsFgAfFN2CEmqhMbEUVGPv/aPjdyWk9
|
||||
pyLE9xR/AkEA2cGwyIunijE5v2rlZAD7C4vRgdcMyCf3uuPcgzFtsR6ZhyQSgLZ8
|
||||
YRPuvwm4cdPJMmO3YwBfxT6XGuSc2k8MjQJBAI0+b8prvpV2+DCQa8L/pjxp+VhR
|
||||
Xrq2GozrHrgR7NRokTB88hwFRJFF6U9iogy9wOx8HA7qxEbwLZuhm/4AhbECQC2a
|
||||
d8h4Ht09E+f3nhTEc87mODkl7WJZpHL6V2sORfeq/eIkds+H6CJ4hy5w/bSw8tjf
|
||||
sz9Di8sGIaUbLZI2rd0CQQCzlVwEtRtoNCyMJTTrkgUuNufLP19RZ5FpyXxBO5/u
|
||||
QastnN77KfUwdj3SJt44U/uh1jAIv4oSLBr8HYUkbnI8
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
func DeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||
var cfg codersdk.DeploymentValues
|
||||
opts := cfg.Options()
|
||||
|
Reference in New Issue
Block a user