mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
fix: Add client certs to OAuth HTTPClient context (#5126)
This commit is contained in:
@ -107,6 +107,7 @@ type Options struct {
|
||||
Experimental bool
|
||||
DeploymentConfig *codersdk.DeploymentConfig
|
||||
UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking.
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// New constructs a Coder API handler.
|
||||
@ -279,7 +280,7 @@ func New(options *Options) *API {
|
||||
for _, gitAuthConfig := range options.GitAuthConfigs {
|
||||
r.Route(fmt.Sprintf("/%s", gitAuthConfig.ID), func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractOAuth2(gitAuthConfig),
|
||||
httpmw.ExtractOAuth2(gitAuthConfig, options.HTTPClient),
|
||||
apiKeyMiddleware,
|
||||
)
|
||||
r.Get("/callback", api.gitAuthCallback(gitAuthConfig))
|
||||
@ -428,12 +429,12 @@ func New(options *Options) *API {
|
||||
r.Get("/authmethods", api.userAuthMethods)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config, options.HTTPClient))
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
})
|
||||
r.Route("/oidc/callback", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig))
|
||||
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig, options.HTTPClient))
|
||||
r.Get("/", api.userOIDC)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
|
@ -40,10 +40,14 @@ func OAuth2(r *http.Request) OAuth2State {
|
||||
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
|
||||
// URLs, and handling the exchange inbound. Any route that does not have
|
||||
// a "code" URL parameter will be redirected.
|
||||
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
|
||||
func ExtractOAuth2(config OAuth2Config, client *http.Client) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if client != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
}
|
||||
|
||||
// Interfaces can hold a nil value
|
||||
if config == nil || reflect.ValueOf(config).IsNil() {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionRequired, codersdk.Response{
|
||||
|
@ -39,14 +39,14 @@ func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(nil, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("RedirectWithoutCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
@ -59,14 +59,14 @@ func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("NoStateCookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("MismatchedState", func(t *testing.T) {
|
||||
@ -77,7 +77,7 @@ func TestOAuth2(t *testing.T) {
|
||||
Value: "mismatch",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("ExchangeCodeAndState", func(t *testing.T) {
|
||||
@ -92,7 +92,7 @@ func TestOAuth2(t *testing.T) {
|
||||
Value: "/dashboard",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{}, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
require.Equal(t, "/dashboard", state.Redirect)
|
||||
})).ServeHTTP(res, req)
|
||||
|
Reference in New Issue
Block a user