chore: add custom samesite options to auth cookies (#16885)

Allows controlling `samesite` cookie settings from the deployment config
This commit is contained in:
Steven Masley
2025-04-08 14:15:14 -05:00
committed by GitHub
parent 389e88ec82
commit 52d555880c
26 changed files with 240 additions and 67 deletions

View File

@ -16,10 +16,10 @@ import (
// for non-GET requests.
// If enforce is false, then CSRF enforcement is disabled. We still want
// to include the CSRF middleware because it will set the CSRF cookie.
func CSRF(secureCookie bool) func(next http.Handler) http.Handler {
func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
mw := nosurf.New(next)
mw.SetBaseCookie(http.Cookie{Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, Secure: secureCookie})
mw.SetBaseCookie(*cookieCfg.Apply(&http.Cookie{Path: "/", HttpOnly: true}))
mw.SetFailureHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessCookie, err := r.Cookie(codersdk.SessionTokenCookie)
if err == nil &&

View File

@ -53,7 +53,7 @@ func TestCSRFExemptList(t *testing.T) {
},
}
mw := httpmw.CSRF(false)
mw := httpmw.CSRF(codersdk.HTTPCookieConfig{})
csrfmw := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})).(*nosurf.CSRFHandler)
for _, c := range cases {
@ -87,7 +87,7 @@ func TestCSRFError(t *testing.T) {
var handler http.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
handler = httpmw.CSRF(false)(handler)
handler = httpmw.CSRF(codersdk.HTTPCookieConfig{})(handler)
// Not testing the error case, just providing the example of things working
// to base the failure tests off of.

View File

@ -40,7 +40,7 @@ func OAuth2(r *http.Request) OAuth2State {
// a "code" URL parameter will be redirected.
// AuthURLOpts are passed to the AuthCodeURL function. If this is nil,
// the default option oauth2.AccessTypeOffline will be used.
func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOpts map[string]string) func(http.Handler) http.Handler {
func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, cookieCfg codersdk.HTTPCookieConfig, authURLOpts map[string]string) func(http.Handler) http.Handler {
opts := make([]oauth2.AuthCodeOption, 0, len(authURLOpts)+1)
opts = append(opts, oauth2.AccessTypeOffline)
for k, v := range authURLOpts {
@ -118,22 +118,20 @@ func ExtractOAuth2(config promoauth.OAuth2Config, client *http.Client, authURLOp
}
}
http.SetCookie(rw, &http.Cookie{
http.SetCookie(rw, cookieCfg.Apply(&http.Cookie{
Name: codersdk.OAuth2StateCookie,
Value: state,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
}))
// Redirect must always be specified, otherwise
// an old redirect could apply!
http.SetCookie(rw, &http.Cookie{
http.SetCookie(rw, cookieCfg.Apply(&http.Cookie{
Name: codersdk.OAuth2RedirectCookie,
Value: redirect,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
}))
http.Redirect(rw, r, config.AuthCodeURL(state, opts...), http.StatusTemporaryRedirect)
return

View File

@ -50,7 +50,7 @@ func TestOAuth2(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(nil, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(nil, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("RedirectWithoutCode", func(t *testing.T) {
@ -58,7 +58,7 @@ func TestOAuth2(t *testing.T) {
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
@ -82,7 +82,7 @@ func TestOAuth2(t *testing.T) {
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape(uri.String()), nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
@ -97,7 +97,7 @@ func TestOAuth2(t *testing.T) {
req := httptest.NewRequest("GET", "/?code=something", nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("NoStateCookie", func(t *testing.T) {
@ -105,7 +105,7 @@ func TestOAuth2(t *testing.T) {
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("MismatchedState", func(t *testing.T) {
@ -117,7 +117,7 @@ func TestOAuth2(t *testing.T) {
})
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("ExchangeCodeAndState", func(t *testing.T) {
@ -133,7 +133,7 @@ func TestOAuth2(t *testing.T) {
})
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, nil)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)
@ -144,7 +144,7 @@ func TestOAuth2(t *testing.T) {
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("foo", "bar"))
authOpts := map[string]string{"foo": "bar"}
httpmw.ExtractOAuth2(tp, nil, authOpts)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{}, authOpts)(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
// Ideally we would also assert that the location contains the query params
// we set in the auth URL but this would essentially be testing the oauth2 package.
@ -157,12 +157,17 @@ func TestOAuth2(t *testing.T) {
req := httptest.NewRequest("GET", "/?oidc_merge_state="+customState+"&redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
tp := newTestOAuth2Provider(t, oauth2.AccessTypeOffline)
httpmw.ExtractOAuth2(tp, nil, nil)(nil).ServeHTTP(res, req)
httpmw.ExtractOAuth2(tp, nil, codersdk.HTTPCookieConfig{
Secure: true,
SameSite: "none",
}, nil)(nil).ServeHTTP(res, req)
found := false
for _, cookie := range res.Result().Cookies() {
if cookie.Name == codersdk.OAuth2StateCookie {
require.Equal(t, cookie.Value, customState, "expected state")
require.Equal(t, true, cookie.Secure, "cookie set to secure")
require.Equal(t, http.SameSiteNoneMode, cookie.SameSite, "same-site = none")
found = true
}
}