mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat: allow configuring OIDC email claim and OIDC auth url parameters (#6867)
This commit: - Allows configuring the OIDC claim Coder uses for email addresses (by default, this is still email) - Allows customising the parameters sent to the upstream identity provider when requesting a token. This is still access_type=offline by default. - Updates documentation related to the above.
This commit is contained in:
6
coderd/apidoc/docs.go
generated
6
coderd/apidoc/docs.go
generated
@ -7294,6 +7294,9 @@ const docTemplate = `{
|
||||
"allow_signups": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"auth_url_params": {
|
||||
"type": "object"
|
||||
},
|
||||
"client_id": {
|
||||
"type": "string"
|
||||
},
|
||||
@ -7306,6 +7309,9 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"email_field": {
|
||||
"type": "string"
|
||||
},
|
||||
"group_mapping": {
|
||||
"type": "object"
|
||||
},
|
||||
|
6
coderd/apidoc/swagger.json
generated
6
coderd/apidoc/swagger.json
generated
@ -6532,6 +6532,9 @@
|
||||
"allow_signups": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"auth_url_params": {
|
||||
"type": "object"
|
||||
},
|
||||
"client_id": {
|
||||
"type": "string"
|
||||
},
|
||||
@ -6544,6 +6547,9 @@
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"email_field": {
|
||||
"type": "string"
|
||||
},
|
||||
"group_mapping": {
|
||||
"type": "object"
|
||||
},
|
||||
|
@ -301,6 +301,12 @@ func New(options *Options) *API {
|
||||
*options.UpdateCheckOptions,
|
||||
)
|
||||
}
|
||||
|
||||
var oidcAuthURLParams map[string]string
|
||||
if options.OIDCConfig != nil {
|
||||
oidcAuthURLParams = options.OIDCConfig.AuthURLParams
|
||||
}
|
||||
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.TemplateScheduleStore.Store(&options.TemplateScheduleStore)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
@ -387,7 +393,7 @@ func New(options *Options) *API {
|
||||
for _, gitAuthConfig := range options.GitAuthConfigs {
|
||||
r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient),
|
||||
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient, nil),
|
||||
apiKeyMiddleware,
|
||||
)
|
||||
r.Get("/callback", api.gitAuthCallback(gitAuthConfig))
|
||||
@ -531,12 +537,12 @@ func New(options *Options) *API {
|
||||
r.Post("/login", api.postLogin)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient))
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil))
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
})
|
||||
r.Route("/oidc/callback", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient))
|
||||
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams))
|
||||
r.Get("/", api.userOIDC)
|
||||
})
|
||||
})
|
||||
|
@ -967,6 +967,8 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
|
||||
}),
|
||||
Provider: provider,
|
||||
UsernameField: "preferred_username",
|
||||
EmailField: "email",
|
||||
AuthURLParams: map[string]string{"access_type": "offline"},
|
||||
GroupField: "groups",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
|
@ -21,6 +21,8 @@ func TestDeploymentValues(t *testing.T) {
|
||||
// values should not be returned
|
||||
cfg.OAuth2.Github.ClientSecret.Set(hi)
|
||||
cfg.OIDC.ClientSecret.Set(hi)
|
||||
cfg.OIDC.AuthURLParams.Set(`{"foo":"bar"}`)
|
||||
cfg.OIDC.EmailField.Set("some_random_field_you_never_expected")
|
||||
cfg.PostgresURL.Set(hi)
|
||||
cfg.SCIMAPIKey.Set(hi)
|
||||
|
||||
@ -32,6 +34,10 @@ func TestDeploymentValues(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
// ensure normal values pass through
|
||||
require.EqualValues(t, true, scrubbed.Values.BrowserOnly.Value())
|
||||
require.NotEmpty(t, cfg.OIDC.AuthURLParams)
|
||||
require.EqualValues(t, cfg.OIDC.AuthURLParams, scrubbed.Values.OIDC.AuthURLParams)
|
||||
require.NotEmpty(t, cfg.OIDC.EmailField)
|
||||
require.EqualValues(t, cfg.OIDC.EmailField, scrubbed.Values.OIDC.EmailField)
|
||||
// ensure secrets are removed
|
||||
require.Empty(t, scrubbed.Values.OAuth2.Github.ClientSecret.Value())
|
||||
require.Empty(t, scrubbed.Values.OIDC.ClientSecret.Value())
|
||||
|
@ -40,7 +40,15 @@ func OAuth2(r *http.Request) OAuth2State {
|
||||
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
|
||||
// URLs, and handling the exchange inbound. Any route that does not have
|
||||
// a "code" URL parameter will be redirected.
|
||||
func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler {
|
||||
// AuthURLOpts are passed to the AuthCodeURL function. If this is nil,
|
||||
// the default option oauth2.AccessTypeOffline will be used.
|
||||
func ExtractOAuth2(config OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler {
|
||||
opts := make([]oauth2.AuthCodeOption, 0, len(authURLOpts)+1)
|
||||
opts = append(opts, oauth2.AccessTypeOffline)
|
||||
for k, v := range authURLOpts {
|
||||
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@ -109,7 +117,7 @@ func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler)
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
|
||||
http.Redirect(rw, r, config.AuthCodeURL(state, opts...), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -15,9 +15,13 @@ import (
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type testOAuth2Provider struct{}
|
||||
type testOAuth2Provider struct {
|
||||
t testing.TB
|
||||
authOpts []oauth2.AuthCodeOption
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
func (p *testOAuth2Provider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
assert.EqualValues(p.t, p.authOpts, opts)
|
||||
return "?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
@ -31,6 +35,13 @@ func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestOAuth2Provider(t testing.TB, opts ...oauth2.AuthCodeOption) *testOAuth2Provider {
|
||||
return &testOAuth2Provider{
|
||||
t: t,
|
||||
authOpts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -38,14 +49,15 @@ func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(nil, nil, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("RedirectWithoutCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
@ -58,14 +70,16 @@ func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("NoStateCookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("MismatchedState", func(t *testing.T) {
|
||||
@ -76,7 +90,8 @@ func TestOAuth2(t *testing.T) {
|
||||
Value: "mismatch",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("ExchangeCodeAndState", func(t *testing.T) {
|
||||
@ -91,9 +106,23 @@ func TestOAuth2(t *testing.T) {
|
||||
Value: "/dashboard",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
require.Equal(t, "/dashboard", state.Redirect)
|
||||
})).ServeHTTP(res, req)
|
||||
})
|
||||
t.Run("CustomAuthCodeOptions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("foo", "bar"))
|
||||
authOpts := map[string]string{"foo": "bar"}
|
||||
httpmw.ExtractOAuth2(tp, nil, authOpts)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
// Ideally we would also assert that the location contains the query params
|
||||
// we set in the auth URL but this would essentially be testing the oauth2 package.
|
||||
// testOAuth2Provider does this job for us.
|
||||
require.NotEmpty(t, location)
|
||||
})
|
||||
}
|
||||
|
@ -477,6 +477,12 @@ type OIDCConfig struct {
|
||||
// UsernameField selects the claim field to be used as the created user's
|
||||
// username.
|
||||
UsernameField string
|
||||
// EmailField selects the claim field to be used as the created user's
|
||||
// email.
|
||||
EmailField string
|
||||
// AuthURLParams are additional parameters to be passed to the OIDC provider
|
||||
// when requesting an access token.
|
||||
AuthURLParams map[string]string
|
||||
// GroupField selects the claim field to be used as the created user's
|
||||
// groups. If the group field is the empty string, then no group updates
|
||||
// will ever come from the OIDC provider.
|
||||
@ -593,7 +599,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
username, _ = usernameRaw.(string)
|
||||
}
|
||||
|
||||
emailRaw, ok := claims["email"]
|
||||
emailRaw, ok := claims[api.OIDCConfig.EmailField]
|
||||
if !ok {
|
||||
// Email is an optional claim in OIDC and
|
||||
// instead the email is frequently sent in
|
||||
|
Reference in New Issue
Block a user