mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
chore: add custom samesite options to auth cookies (#16885)
Allows controlling `samesite` cookie settings from the deployment config
This commit is contained in:
@ -16,10 +16,10 @@ import (
|
||||
// for non-GET requests.
|
||||
// If enforce is false, then CSRF enforcement is disabled. We still want
|
||||
// to include the CSRF middleware because it will set the CSRF cookie.
|
||||
func CSRF(secureCookie bool) func(next http.Handler) http.Handler {
|
||||
func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
mw := nosurf.New(next)
|
||||
mw.SetBaseCookie(http.Cookie{Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, Secure: secureCookie})
|
||||
mw.SetBaseCookie(*cookieCfg.Apply(&http.Cookie{Path: "/", HttpOnly: true}))
|
||||
mw.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessCookie, err := r.Cookie(codersdk.SessionTokenCookie)
|
||||
if err == nil &&
|
||||
|
@ -53,7 +53,7 @@ func TestCSRFExemptList(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
mw := httpmw.CSRF(false)
|
||||
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
|
||||
csrfmw := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).(*nosurf.CSRFHandler)
|
||||
|
||||
for _, c := range cases {
|
||||
@ -87,7 +87,7 @@ func TestCSRFError(t *testing.T) {
|
||||
var handler http.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler = httpmw.CSRF(false)(handler)
|
||||
handler = httpmw.CSRF(codersdk.HTTPCookieConfig{})(handler)
|
||||
|
||||
// Not testing the error case, just providing the example of things working
|
||||
// to base the failure tests off of.
|
||||
|
@ -40,7 +40,7 @@ func OAuth2(r *http.Request) OAuth2State {
|
||||
// a "code" URL parameter will be redirected.
|
||||
// AuthURLOpts are passed to the AuthCodeURL function. If this is nil,
|
||||
// the default option oauth2.AccessTypeOffline will be used.
|
||||
func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler {
|
||||
func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, cookieCfg codersdk.HTTPCookieConfig, authURLOpts map[string]string) func(http.Handler) http.Handler {
|
||||
opts := make([]oauth2.AuthCodeOption, 0, len(authURLOpts)+1)
|
||||
opts = append(opts, oauth2.AccessTypeOffline)
|
||||
for k, v := range authURLOpts {
|
||||
@ -118,22 +118,20 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp
|
||||
}
|
||||
}
|
||||
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
http.SetCookie(rw, cookieCfg.Apply(&http.Cookie{
|
||||
Name: codersdk.OAuth2StateCookie,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}))
|
||||
// Redirect must always be specified, otherwise
|
||||
// an old redirect could apply!
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
http.SetCookie(rw, cookieCfg.Apply(&http.Cookie{
|
||||
Name: codersdk.OAuth2RedirectCookie,
|
||||
Value: redirect,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}))
|
||||
|
||||
http.Redirect(rw, r, config.AuthCodeURL(state, opts...), http.StatusTemporaryRedirect)
|
||||
return
|
||||
|
@ -50,7 +50,7 @@ func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(nil, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(nil, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("RedirectWithoutCode", func(t *testing.T) {
|
||||
@ -58,7 +58,7 @@ func TestOAuth2(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
@ -82,7 +82,7 @@ func TestOAuth2(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape(uri.String()), nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
@ -97,7 +97,7 @@ func TestOAuth2(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?code=something", nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("NoStateCookie", func(t *testing.T) {
|
||||
@ -105,7 +105,7 @@ func TestOAuth2(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("MismatchedState", func(t *testing.T) {
|
||||
@ -117,7 +117,7 @@ func TestOAuth2(t *testing.T) {
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("ExchangeCodeAndState", func(t *testing.T) {
|
||||
@ -133,7 +133,7 @@ func TestOAuth2(t *testing.T) {
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
require.Equal(t, "/dashboard", state.Redirect)
|
||||
})).ServeHTTP(res, req)
|
||||
@ -144,7 +144,7 @@ func TestOAuth2(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("foo", "bar"))
|
||||
authOpts := map[string]string{"foo": "bar"}
|
||||
httpmw.ExtractOAuth2(tp, nil, authOpts)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, authOpts)(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
// Ideally we would also assert that the location contains the query params
|
||||
// we set in the auth URL but this would essentially be testing the oauth2 package.
|
||||
@ -157,12 +157,17 @@ func TestOAuth2(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/?oidc_merge_state="+customState+"&redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
|
||||
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
|
||||
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{
|
||||
Secure: true,
|
||||
SameSite: "none",
|
||||
}, nil)(nil).ServeHTTP(res, req)
|
||||
|
||||
found := false
|
||||
for _, cookie := range res.Result().Cookies() {
|
||||
if cookie.Name == codersdk.OAuth2StateCookie {
|
||||
require.Equal(t, cookie.Value, customState, "expected state")
|
||||
require.Equal(t, true, cookie.Secure, "cookie set to secure")
|
||||
require.Equal(t, http.SameSiteNoneMode, cookie.SameSite, "same-site = none")
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user