fix: Redirect to login when unauthenticated and requesting a workspace app (#2903)

Fixes #2884.
This commit is contained in:
Kyle Carberry
2022-07-11 13:46:01 -05:00
committed by GitHub
parent 08d90f7b4f
commit 2c89e07e12
12 changed files with 94 additions and 46 deletions

View File

@ -103,10 +103,10 @@ func New(options *Options) *API {
siteHandler: site.Handler(site.FS(), binFS),
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
})
}
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
r.Use(
func(next http.Handler) http.Handler {
@ -121,7 +121,7 @@ func New(options *Options) *API {
apps := func(r chi.Router) {
r.Use(
httpmw.RateLimitPerMinute(options.APIRateLimit),
apiKeyMiddleware,
httpmw.ExtractAPIKey(options.Database, oauthConfigs, true),
httpmw.ExtractUserParam(api.Database),
)
r.HandleFunc("/*", api.workspaceAppsProxyPath)

View File

@ -56,9 +56,26 @@ type OAuth2Configs struct {
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
// nolint:revive
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Write wraps writing a response to redirect if the handler
// specified it should. This redirect is used for user-facing
// pages like workspace applications.
write := func(code int, response httpapi.Response) {
if redirectToLogin {
q := r.URL.Query()
q.Add("message", response.Message)
q.Add("redirect", r.URL.Path+"?"+r.URL.RawQuery)
r.URL.RawQuery = q.Encode()
r.URL.Path = "/login"
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
httpapi.Write(rw, code, response)
}
var cookieValue string
cookie, err := r.Cookie(SessionTokenKey)
if err != nil {
@ -67,7 +84,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
cookieValue = cookie.Value
}
if cookieValue == "" {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Cookie %q or query parameter must be provided.", SessionTokenKey),
})
return
@ -75,7 +92,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
parts := strings.Split(cookieValue, "-")
// APIKeys are formatted: ID-SECRET
if len(parts) != 2 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key format.", SessionTokenKey),
})
return
@ -84,13 +101,13 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
keySecret := parts[1]
// Ensuring key lengths are valid.
if len(keyID) != 10 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key id.", SessionTokenKey),
})
return
}
if len(keySecret) != 22 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("Invalid %q cookie API key secret.", SessionTokenKey),
})
return
@ -98,12 +115,12 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
key, err := db.GetAPIKeyByID(r.Context(), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "API key is invalid.",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching API key by id.",
Detail: err.Error(),
})
@ -113,7 +130,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
// Checking to see if the secret is valid.
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "API key secret is invalid.",
})
return
@ -130,7 +147,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
case database.LoginTypeGithub:
oauthConfig = oauth.Github
default:
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("Unexpected authentication type %q.", key.LoginType),
})
return
@ -142,7 +159,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
Expiry: key.OAuthExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "Could not refresh expired Oauth token.",
Detail: err.Error(),
})
@ -158,7 +175,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
// Checking if the key is expired.
if key.ExpiresAt.Before(now) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
})
return
@ -200,7 +217,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
OAuthExpiry: key.OAuthExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
write(http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("API key couldn't update: %s.", err.Error()),
})
return
@ -212,7 +229,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
// is to block 'suspended' users from accessing the platform.
roles, err := db.GetAuthorizationUserRoles(r.Context(), key.UserID)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: "Internal error fetching user's roles.",
Detail: err.Error(),
})
@ -220,7 +237,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
}
if roles.Status != database.UserStatusActive {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
write(http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("User is not active (status = %q). Contact an admin to reactivate your account.", roles.Status),
})
return

View File

@ -44,12 +44,28 @@ func TestAPIKey(t *testing.T) {
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("NoCookieRedirects", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil, true)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
location, err := res.Location()
require.NoError(t, err)
require.NotEmpty(t, location.Query().Get("message"))
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
})
t.Run("InvalidFormat", func(t *testing.T) {
t.Parallel()
var (
@ -62,7 +78,7 @@ func TestAPIKey(t *testing.T) {
Value: "test-wow-hello",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -80,7 +96,7 @@ func TestAPIKey(t *testing.T) {
Value: "test-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -98,7 +114,7 @@ func TestAPIKey(t *testing.T) {
Value: "testtestid-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -117,7 +133,7 @@ func TestAPIKey(t *testing.T) {
Value: fmt.Sprintf("%s-%s", id, secret),
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -145,7 +161,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -172,7 +188,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
@ -200,7 +216,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
@ -238,7 +254,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
@ -273,7 +289,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
@ -308,7 +324,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
@ -344,7 +360,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
@ -391,7 +407,7 @@ func TestAPIKey(t *testing.T) {
return token, nil
}),
},
})(successHandler).ServeHTTP(rw, r)
}, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
@ -428,7 +444,7 @@ func TestAPIKey(t *testing.T) {
UserID: user.ID,
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)

View File

@ -83,7 +83,7 @@ func TestExtractUserRoles(t *testing.T) {
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}),
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}, false),
)
rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) {
roles := httpmw.AuthorizationUserRoles(r)

View File

@ -67,7 +67,7 @@ func TestOrganizationParam(t *testing.T) {
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
@ -87,7 +87,7 @@ func TestOrganizationParam(t *testing.T) {
)
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.NewString())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
@ -107,7 +107,7 @@ func TestOrganizationParam(t *testing.T) {
)
chi.RouteContext(r.Context()).URLParams.Add("organization", "not-a-uuid")
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
@ -135,7 +135,7 @@ func TestOrganizationParam(t *testing.T) {
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
chi.RouteContext(r.Context()).URLParams.Add("user", u.ID.String())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractUserParam(db),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractOrganizationMemberParam(db),
@ -172,7 +172,7 @@ func TestOrganizationParam(t *testing.T) {
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractOrganizationParam(db),
httpmw.ExtractUserParam(db),
httpmw.ExtractOrganizationMemberParam(db),

View File

@ -132,7 +132,7 @@ func TestTemplateParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractTemplateParam(db),
httpmw.ExtractOrganizationParam(db),
)

View File

@ -124,7 +124,7 @@ func TestTemplateVersionParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractTemplateVersionParam(db),
httpmw.ExtractOrganizationParam(db),
)

View File

@ -56,7 +56,7 @@ func TestUserParam(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
@ -72,7 +72,7 @@ func TestUserParam(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
@ -91,7 +91,7 @@ func TestUserParam(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)

View File

@ -132,7 +132,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractWorkspaceAgentParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {

View File

@ -107,7 +107,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractWorkspaceBuildParam(db),
httpmw.ExtractWorkspaceParam(db),
)

View File

@ -97,7 +97,7 @@ func TestWorkspaceParam(t *testing.T) {
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractAPIKey(db, nil, false),
httpmw.ExtractWorkspaceParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {

View File

@ -86,6 +86,21 @@ func TestWorkspaceAppsProxyPath(t *testing.T) {
return http.ErrUseLastResponse
}
t.Run("RedirectsWithoutAuth", func(t *testing.T) {
t.Parallel()
client := codersdk.New(client.URL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example", nil)
require.NoError(t, err)
defer resp.Body.Close()
location, err := resp.Location()
require.NoError(t, err)
require.Equal(t, "/login", location.Path)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
t.Run("RedirectsWithSlash", func(t *testing.T) {
t.Parallel()
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example", nil)