mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
fix: remove refresh oauth logic on OIDC login (#8950)
* fix: do not do oauth refresh logic on oidc login
This commit is contained in:
@ -1022,9 +1022,31 @@ func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certif
|
||||
type OIDCConfig struct {
|
||||
key *rsa.PrivateKey
|
||||
issuer string
|
||||
// These are optional
|
||||
refreshToken string
|
||||
oidcTokenExpires func() time.Time
|
||||
tokenSource func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
|
||||
func WithRefreshToken(token string) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.refreshToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenExpires(expFunc func() time.Time) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.oidcTokenExpires = expFunc
|
||||
}
|
||||
}
|
||||
|
||||
func WithTokenSource(src func() (*oauth2.Token, error)) func(cfg *OIDCConfig) {
|
||||
return func(cfg *OIDCConfig) {
|
||||
cfg.tokenSource = src
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCConfig(t *testing.T, issuer string, opts ...func(cfg *OIDCConfig)) *OIDCConfig {
|
||||
t.Helper()
|
||||
|
||||
block, _ := pem.Decode([]byte(testRSAPrivateKey))
|
||||
@ -1035,33 +1057,58 @@ func NewOIDCConfig(t *testing.T, issuer string) *OIDCConfig {
|
||||
issuer = "https://coder.com"
|
||||
}
|
||||
|
||||
return &OIDCConfig{
|
||||
cfg := &OIDCConfig{
|
||||
key: pkey,
|
||||
issuer: issuer,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (*OIDCConfig) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
type tokenSource struct {
|
||||
src func() (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
func (*OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
func (s tokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.src()
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
if cfg.tokenSource == nil {
|
||||
return nil
|
||||
}
|
||||
return tokenSource{
|
||||
src: cfg.tokenSource,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *OIDCConfig) Exchange(_ context.Context, code string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
token, err := base64.StdEncoding.DecodeString(code)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode code: %w", err)
|
||||
}
|
||||
|
||||
var exp time.Time
|
||||
if cfg.oidcTokenExpires != nil {
|
||||
exp = cfg.oidcTokenExpires()
|
||||
}
|
||||
|
||||
return (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
AccessToken: "token",
|
||||
RefreshToken: cfg.refreshToken,
|
||||
Expiry: exp,
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": string(token),
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
func (cfg *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := claims["exp"]; !ok {
|
||||
@ -1069,20 +1116,20 @@ func (o *OIDCConfig) EncodeClaims(t *testing.T, claims jwt.MapClaims) string {
|
||||
}
|
||||
|
||||
if _, ok := claims["iss"]; !ok {
|
||||
claims["iss"] = o.issuer
|
||||
claims["iss"] = cfg.issuer
|
||||
}
|
||||
|
||||
if _, ok := claims["sub"]; !ok {
|
||||
claims["sub"] = "testme"
|
||||
}
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(o.key)
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(cfg.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return base64.StdEncoding.EncodeToString([]byte(signed))
|
||||
}
|
||||
|
||||
func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
func (cfg *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts ...func(cfg *coderd.OIDCConfig)) *coderd.OIDCConfig {
|
||||
// By default, the provider can be empty.
|
||||
// This means it won't support any endpoints!
|
||||
provider := &oidc.Provider{}
|
||||
@ -1099,10 +1146,10 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
|
||||
}
|
||||
provider = cfg.NewProvider(context.Background())
|
||||
}
|
||||
cfg := &coderd.OIDCConfig{
|
||||
OAuth2Config: o,
|
||||
Verifier: oidc.NewVerifier(o.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{o.key.Public()},
|
||||
newCFG := &coderd.OIDCConfig{
|
||||
OAuth2Config: cfg,
|
||||
Verifier: oidc.NewVerifier(cfg.issuer, &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{cfg.key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
}),
|
||||
@ -1113,9 +1160,9 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
opt(newCFG)
|
||||
}
|
||||
return cfg
|
||||
return newCFG
|
||||
}
|
||||
|
||||
// NewAzureInstanceIdentity returns a metadata client and ID token validator for faking
|
||||
|
Reference in New Issue
Block a user