feat: allow configuring OIDC email claim and OIDC auth url parameters (#6867)

This commit:

- Allows configuring the OIDC claim Coder uses for email addresses (by default, this is still email)
- Allows customising the parameters sent to the upstream identity provider when requesting a token. This is still access_type=offline by default.
- Updates documentation related to the above.
This commit is contained in:
Cian Johnston
2023-03-30 09:36:57 +01:00
committed by GitHub
parent 6981f89cd8
commit 563c3ade06
17 changed files with 379 additions and 22 deletions

6
coderd/apidoc/docs.go generated
View File

@ -7294,6 +7294,9 @@ const docTemplate = `{
"allow_signups": {
"type": "boolean"
},
"auth_url_params": {
"type": "object"
},
"client_id": {
"type": "string"
},
@ -7306,6 +7309,9 @@ const docTemplate = `{
"type": "string"
}
},
"email_field": {
"type": "string"
},
"group_mapping": {
"type": "object"
},

View File

@ -6532,6 +6532,9 @@
"allow_signups": {
"type": "boolean"
},
"auth_url_params": {
"type": "object"
},
"client_id": {
"type": "string"
},
@ -6544,6 +6547,9 @@
"type": "string"
}
},
"email_field": {
"type": "string"
},
"group_mapping": {
"type": "object"
},

View File

@ -301,6 +301,12 @@ func New(options *Options) *API {
*options.UpdateCheckOptions,
)
}
var oidcAuthURLParams map[string]string
if options.OIDCConfig != nil {
oidcAuthURLParams = options.OIDCConfig.AuthURLParams
}
api.Auditor.Store(&options.Auditor)
api.TemplateScheduleStore.Store(&options.TemplateScheduleStore)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
@ -387,7 +393,7 @@ func New(options *Options) *API {
for _, gitAuthConfig := range options.GitAuthConfigs {
r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient),
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient, nil),
apiKeyMiddleware,
)
r.Get("/callback", api.gitAuthCallback(gitAuthConfig))
@ -531,12 +537,12 @@ func New(options *Options) *API {
r.Post("/login", api.postLogin)
r.Route("/oauth2", func(r chi.Router) {
r.Route("/github", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient))
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient, nil))
r.Get("/callback", api.userOAuth2Github)
})
})
r.Route("/oidc/callback", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient))
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient, oidcAuthURLParams))
r.Get("/", api.userOIDC)
})
})

View File

@ -967,6 +967,8 @@ func (o *OIDCConfig) OIDCConfig(t *testing.T, userInfoClaims jwt.MapClaims, opts
}),
Provider: provider,
UsernameField: "preferred_username",
EmailField: "email",
AuthURLParams: map[string]string{"access_type": "offline"},
GroupField: "groups",
}
for _, opt := range opts {

View File

@ -21,6 +21,8 @@ func TestDeploymentValues(t *testing.T) {
// values should not be returned
cfg.OAuth2.Github.ClientSecret.Set(hi)
cfg.OIDC.ClientSecret.Set(hi)
cfg.OIDC.AuthURLParams.Set(`{"foo":"bar"}`)
cfg.OIDC.EmailField.Set("some_random_field_you_never_expected")
cfg.PostgresURL.Set(hi)
cfg.SCIMAPIKey.Set(hi)
@ -32,6 +34,10 @@ func TestDeploymentValues(t *testing.T) {
require.NoError(t, err)
// ensure normal values pass through
require.EqualValues(t, true, scrubbed.Values.BrowserOnly.Value())
require.NotEmpty(t, cfg.OIDC.AuthURLParams)
require.EqualValues(t, cfg.OIDC.AuthURLParams, scrubbed.Values.OIDC.AuthURLParams)
require.NotEmpty(t, cfg.OIDC.EmailField)
require.EqualValues(t, cfg.OIDC.EmailField, scrubbed.Values.OIDC.EmailField)
// ensure secrets are removed
require.Empty(t, scrubbed.Values.OAuth2.Github.ClientSecret.Value())
require.Empty(t, scrubbed.Values.OIDC.ClientSecret.Value())

View File

@ -40,7 +40,15 @@ func OAuth2(r *http.Request) OAuth2State {
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
// URLs, and handling the exchange inbound. Any route that does not have
// a "code" URL parameter will be redirected.
func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler {
// AuthURLOpts are passed to the AuthCodeURL function. If this is nil,
// the default option oauth2.AccessTypeOffline will be used.
func ExtractOAuth2(config OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler {
opts := make([]oauth2.AuthCodeOption, 0, len(authURLOpts)+1)
opts = append(opts, oauth2.AccessTypeOffline)
for k, v := range authURLOpts {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -109,7 +117,7 @@ func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler)
SameSite: http.SameSiteLaxMode,
})
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
http.Redirect(rw, r, config.AuthCodeURL(state, opts...), http.StatusTemporaryRedirect)
return
}

View File

@ -15,9 +15,13 @@ import (
"github.com/coder/coder/codersdk"
)
type testOAuth2Provider struct{}
type testOAuth2Provider struct {
t testing.TB
authOpts []oauth2.AuthCodeOption
}
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
func (p *testOAuth2Provider) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
assert.EqualValues(p.t, p.authOpts, opts)
return "?state=" + url.QueryEscape(state)
}
@ -31,6 +35,13 @@ func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth
return nil
}
func newTestOAuth2Provider(t testing.TB, opts ...oauth2.AuthCodeOption) *testOAuth2Provider {
return &testOAuth2Provider{
t: t,
authOpts: opts,
}
}
// nolint:bodyclose
func TestOAuth2(t *testing.T) {
t.Parallel()
@ -38,14 +49,15 @@ func TestOAuth2(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(nil, nil, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("RedirectWithoutCode", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
@ -58,14 +70,16 @@ func TestOAuth2(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("NoStateCookie", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("MismatchedState", func(t *testing.T) {
@ -76,7 +90,8 @@ func TestOAuth2(t *testing.T) {
Value: "mismatch",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("ExchangeCodeAndState", func(t *testing.T) {
@ -91,9 +106,23 @@ func TestOAuth2(t *testing.T) {
Value: "/dashboard",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)
})
t.Run("CustomAuthCodeOptions", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("foo", "bar"))
authOpts := map[string]string{"foo": "bar"}
httpmw.ExtractOAuth2(tp, nil, authOpts)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
// Ideally we would also assert that the location contains the query params
// we set in the auth URL but this would essentially be testing the oauth2 package.
// testOAuth2Provider does this job for us.
require.NotEmpty(t, location)
})
}

View File

@ -477,6 +477,12 @@ type OIDCConfig struct {
// UsernameField selects the claim field to be used as the created user's
// username.
UsernameField string
// EmailField selects the claim field to be used as the created user's
// email.
EmailField string
// AuthURLParams are additional parameters to be passed to the OIDC provider
// when requesting an access token.
AuthURLParams map[string]string
// GroupField selects the claim field to be used as the created user's
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
@ -593,7 +599,7 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
username, _ = usernameRaw.(string)
}
emailRaw, ok := claims["email"]
emailRaw, ok := claims[api.OIDCConfig.EmailField]
if !ok {
// Email is an optional claim in OIDC and
// instead the email is frequently sent in