fix: Add client certs to OAuth HTTPClient context (#5126)

This commit is contained in:
Arthur Normand
2022-12-14 09:44:29 -05:00
committed by GitHub
parent 663f7a3f12
commit ad0dd1be5d
4 changed files with 55 additions and 29 deletions

View File

@ -107,6 +107,7 @@ type Options struct {
Experimental bool
DeploymentConfig *codersdk.DeploymentConfig
UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking.
HTTPClient *http.Client
}
// New constructs a Coder API handler.
@ -279,7 +280,7 @@ func New(options *Options) *API {
for _, gitAuthConfig := range options.GitAuthConfigs {
r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) {
r.Use(
httpmw.ExtractOAuth2(gitAuthConfig),
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient),
apiKeyMiddleware,
)
r.Get("/callback", api.gitAuthCallback(gitAuthConfig))
@ -428,12 +429,12 @@ func New(options *Options) *API {
r.Get("/authmethods", api.userAuthMethods)
r.Route("/oauth2", func(r chi.Router) {
r.Route("/github", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient))
r.Get("/callback", api.userOAuth2Github)
})
})
r.Route("/oidc/callback", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig))
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient))
r.Get("/", api.userOIDC)
})
r.Group(func(r chi.Router) {

View File

@ -40,10 +40,14 @@ func OAuth2(r *http.Request) OAuth2State {
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
// URLs, and handling the exchange inbound. Any route that does not have
// a "code" URL parameter will be redirected.
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if client != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
}
// Interfaces can hold a nil value
if config == nil || reflect.ValueOf(config).IsNil() {
httpapi.Write(ctx, rw, http.StatusPreconditionRequired, codersdk.Response{

View File

@ -39,14 +39,14 @@ func TestOAuth2(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
})
t.Run("RedirectWithoutCode", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
@ -59,14 +59,14 @@ func TestOAuth2(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("NoStateCookie", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("MismatchedState", func(t *testing.T) {
@ -77,7 +77,7 @@ func TestOAuth2(t *testing.T) {
Value: "mismatch",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("ExchangeCodeAndState", func(t *testing.T) {
@ -92,7 +92,7 @@ func TestOAuth2(t *testing.T) {
Value: "/dashboard",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)