fix: remove refresh oauth logic on OIDC login (#8950)

* fix: do not do oauth refresh logic on oidc login
This commit is contained in:
Steven Masley
2023-08-08 10:05:12 -05:00
committed by GitHub
parent 1d4a72f43f
commit 5339a31532
6 changed files with 217 additions and 68 deletions

View File

@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
type OIDCConfig struct {
key *rsa.PrivateKey
issuer string
// These are optional
refreshToken string
oidcTokenExpires func() time.Time
tokenSource func() (*oauth2.Token, error)
}
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.refreshToken = token
}
}
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.oidcTokenExpires = expFunc
}
}
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
return func(cfg *OIDCConfig) {
cfg.tokenSource = src
}
}
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
t.Helper()
block, _ := pem.Decode([]byte(testRSAPrivateKey))
@ -1035,33 +1057,58 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
issuer = "https://coder.com"
}
return &OIDCConfig{
cfg := &OIDCConfig{
key: pkey,
issuer: issuer,
}
for _, opt := range opts {
opt(cfg)
}
return cfg
}
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
type tokenSource struct {
src func() (*oauth2.Token, error)
}
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
func (s tokenSource) Token() (*oauth2.Token, error) {
return s.src()
}
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
if cfg.tokenSource == nil {
return nil
}
return tokenSource{
src: cfg.tokenSource,
}
}
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
token, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, xerrors.Errorf("decode code: %w", err)
}
var exp time.Time
if cfg.oidcTokenExpires != nil {
exp = cfg.oidcTokenExpires()
}
return (&oauth2.Token{
AccessToken: "token",
AccessToken: "token",
RefreshToken: cfg.refreshToken,
Expiry: exp,
}).WithExtra(map[string]interface{}{
"id_token": string(token),
}), nil
}
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
t.Helper()
if _, ok := claims["exp"]; !ok {
@ -1069,20 +1116,20 @@ func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
}
if _, ok := claims["iss"]; !ok {
claims["iss"] = o.issuer
claims["iss"] = cfg.issuer
}
if _, ok := claims["sub"]; !ok {
claims["sub"] = "testme"
}
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString([]byte(signed))
}
func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
// By default, the provider can be empty.
// This means it won't support any endpoints!
provider := &oidc.Provider{}
@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
}
provider = cfg.NewProvider(context.Background())
}
cfg := &coderd.OIDCConfig{
OAuth2Config: o,
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{o.key.Public()},
newCFG := &coderd.OIDCConfig{
OAuth2Config: cfg,
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
}, &oidc.Config{
SkipClientIDCheck: true,
}),
@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
GroupField: "groups",
}
for _, opt := range opts {
opt(cfg)
opt(newCFG)
}
return cfg
return newCFG
}
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking