mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
feat: secure and cross-domain subdomain-based proxying (#4136)
Co-authored-by: Kyle Carberry <kyle@carberry.com>
This commit is contained in:
@ -13,25 +13,38 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// The special cookie name used for subdomain-based application proxying.
|
||||
// TODO: this will make dogfooding harder so come up with a more unique
|
||||
// solution
|
||||
//
|
||||
//nolint:gosec
|
||||
const DevURLSessionTokenCookie = "coder_devurl_session_token"
|
||||
|
||||
type apiKeyContextKey struct{}
|
||||
|
||||
// APIKeyOptional may return an API key from the ExtractAPIKey handler.
|
||||
func APIKeyOptional(r *http.Request) (database.APIKey, bool) {
|
||||
key, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
return key, ok
|
||||
}
|
||||
|
||||
// APIKey returns the API key from the ExtractAPIKey handler.
|
||||
func APIKey(r *http.Request) database.APIKey {
|
||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
key, ok := APIKeyOptional(r)
|
||||
if !ok {
|
||||
panic("developer error: apikey middleware not provided")
|
||||
panic("developer error: ExtractAPIKey middleware not provided")
|
||||
}
|
||||
return apiKey
|
||||
return key
|
||||
}
|
||||
|
||||
// User roles are the 'subject' field of Authorize()
|
||||
@ -44,10 +57,17 @@ type Authorization struct {
|
||||
Scope database.APIKeyScope
|
||||
}
|
||||
|
||||
// UserAuthorizationOptional may return the roles and scope used for
|
||||
// authorization. Depends on the ExtractAPIKey handler.
|
||||
func UserAuthorizationOptional(r *http.Request) (Authorization, bool) {
|
||||
auth, ok := r.Context().Value(userAuthKey{}).(Authorization)
|
||||
return auth, ok
|
||||
}
|
||||
|
||||
// UserAuthorization returns the roles and scope used for authorization. Depends
|
||||
// on the ExtractAPIKey handler.
|
||||
func UserAuthorization(r *http.Request) Authorization {
|
||||
auth, ok := r.Context().Value(userAuthKey{}).(Authorization)
|
||||
auth, ok := UserAuthorizationOptional(r)
|
||||
if !ok {
|
||||
panic("developer error: ExtractAPIKey middleware not provided")
|
||||
}
|
||||
@ -66,63 +86,51 @@ const (
|
||||
internalErrorMessage string = "An internal error occurred. Please try again or contact the system administrator."
|
||||
)
|
||||
|
||||
type loginURLKey struct{}
|
||||
type ExtractAPIKeyConfig struct {
|
||||
DB database.Store
|
||||
OAuth2Configs *OAuth2Configs
|
||||
RedirectToLogin bool
|
||||
|
||||
func getLoginURL(r *http.Request) (*url.URL, bool) {
|
||||
val, ok := r.Context().Value(loginURLKey{}).(*url.URL)
|
||||
return val, ok
|
||||
// Optional governs whether the API key is optional. Use this if you want to
|
||||
// allow unauthenticated requests.
|
||||
//
|
||||
// If true and no session token is provided, nothing will be written to the
|
||||
// request context. Use the APIKeyOptional and UserAuthorizationOptional
|
||||
// functions to retrieve the API key and authorization instead of the
|
||||
// regular ones.
|
||||
//
|
||||
// If true and the API key is invalid (i.e. deleted, expired), the cookie
|
||||
// will be deleted and the request will continue. If the request is not a
|
||||
// cookie-based request, the request will be rejected with a 401.
|
||||
Optional bool
|
||||
}
|
||||
|
||||
// UseLoginURL sets the login URL to use for the request for handlers like
|
||||
// ExtractAPIKey.
|
||||
func UseLoginURL(loginURL *url.URL) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), loginURLKey{}, loginURL)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key.
|
||||
// It handles extending an API key if it comes close to expiry,
|
||||
// updating the last used time in the database.
|
||||
// ExtractAPIKey requires authentication using a valid API key. It handles
|
||||
// extending an API key if it comes close to expiry, updating the last used time
|
||||
// in the database.
|
||||
// nolint:revive
|
||||
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool) func(http.Handler) http.Handler {
|
||||
func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
// Write wraps writing a response to redirect if the handler
|
||||
// specified it should. This redirect is used for user-facing
|
||||
// pages like workspace applications.
|
||||
// specified it should. This redirect is used for user-facing pages
|
||||
// like workspace applications.
|
||||
write := func(code int, response codersdk.Response) {
|
||||
if redirectToLogin {
|
||||
var (
|
||||
u = &url.URL{
|
||||
Path: "/login",
|
||||
}
|
||||
redirectURL = func() string {
|
||||
path := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
path += "?" + r.URL.RawQuery
|
||||
}
|
||||
return path
|
||||
}()
|
||||
)
|
||||
if loginURL, ok := getLoginURL(r); ok {
|
||||
u = loginURL
|
||||
// Don't redirect to the current page, as it may be on
|
||||
// a different domain and we have issues determining the
|
||||
// scheme to redirect to.
|
||||
redirectURL = ""
|
||||
if cfg.RedirectToLogin {
|
||||
path := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
path += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
q := url.Values{}
|
||||
q.Add("message", response.Message)
|
||||
if redirectURL != "" {
|
||||
q.Add("redirect", redirectURL)
|
||||
q.Add("redirect", path)
|
||||
|
||||
u := &url.URL{
|
||||
Path: "/login",
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(rw, r, u.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
@ -131,44 +139,42 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
httpapi.Write(ctx, rw, code, response)
|
||||
}
|
||||
|
||||
cookieValue := apiTokenFromRequest(r)
|
||||
if cookieValue == "" {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
// optionalWrite wraps write, but will pass the request on to the
|
||||
// next handler if the configuration says the API key is optional.
|
||||
//
|
||||
// It should be used when the API key is not provided or is invalid,
|
||||
// but not when there are other errors.
|
||||
optionalWrite := func(code int, response codersdk.Response) {
|
||||
if cfg.Optional {
|
||||
next.ServeHTTP(rw, r)
|
||||
return
|
||||
}
|
||||
|
||||
write(code, response)
|
||||
}
|
||||
|
||||
token := apiTokenFromRequest(r)
|
||||
if token == "" {
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Cookie %q or query parameter must be provided.", codersdk.SessionTokenKey),
|
||||
})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(cookieValue, "-")
|
||||
// APIKeys are formatted: ID-SECRET
|
||||
if len(parts) != 2 {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
|
||||
keyID, keySecret, err := SplitAPIToken(token)
|
||||
if err != nil {
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Invalid %q cookie API key format.", codersdk.SessionTokenKey),
|
||||
Detail: "Invalid API key format: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
keyID := parts[0]
|
||||
keySecret := parts[1]
|
||||
// Ensuring key lengths are valid.
|
||||
if len(keyID) != 10 {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Invalid %q cookie API key id.", codersdk.SessionTokenKey),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(keySecret) != 22 {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("Invalid %q cookie API key secret.", codersdk.SessionTokenKey),
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := db.GetAPIKeyByID(r.Context(), keyID)
|
||||
|
||||
key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: "API key is invalid.",
|
||||
})
|
||||
@ -180,23 +186,25 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
})
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
// Checking to see if the secret is valid.
|
||||
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
hashedSecret := sha256.Sum256([]byte(keySecret))
|
||||
if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 {
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: "API key secret is invalid.",
|
||||
})
|
||||
return
|
||||
}
|
||||
now := database.Now()
|
||||
// Tracks if the API key has properties updated!
|
||||
changed := false
|
||||
|
||||
var link database.UserLink
|
||||
var (
|
||||
link database.UserLink
|
||||
now = database.Now()
|
||||
// Tracks if the API key has properties updated
|
||||
changed = false
|
||||
)
|
||||
if key.LoginType != database.LoginTypePassword {
|
||||
link, err = db.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
@ -207,14 +215,14 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
})
|
||||
return
|
||||
}
|
||||
// Check if the OAuth token is expired!
|
||||
// Check if the OAuth token is expired
|
||||
if link.OAuthExpiry.Before(now) && !link.OAuthExpiry.IsZero() {
|
||||
var oauthConfig OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = oauth.Github
|
||||
oauthConfig = cfg.OAuth2Configs.Github
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = oauth.OIDC
|
||||
oauthConfig = cfg.OAuth2Configs.OIDC
|
||||
default:
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
@ -222,7 +230,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
})
|
||||
return
|
||||
}
|
||||
// If it is, let's refresh it from the provided config!
|
||||
// If it is, let's refresh it from the provided config
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
RefreshToken: link.OAuthRefreshToken,
|
||||
@ -245,7 +253,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
|
||||
// Checking if the key is expired.
|
||||
if key.ExpiresAt.Before(now) {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: signedOutErrorMessage,
|
||||
Detail: fmt.Sprintf("API key expired at %q.", key.ExpiresAt.String()),
|
||||
})
|
||||
@ -278,7 +286,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
@ -294,7 +302,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = db.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
@ -314,7 +322,7 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
roles, err := db.GetAuthorizationUserRoles(r.Context(), key.UserID)
|
||||
roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID)
|
||||
if err != nil {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
@ -346,9 +354,10 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
// apiTokenFromRequest returns the api token from the request.
|
||||
// Find the session token from:
|
||||
// 1: The cookie
|
||||
// 2: The old cookie
|
||||
// 3. The coder_session_token query parameter
|
||||
// 4. The custom auth header
|
||||
// 1: The devurl cookie
|
||||
// 3: The old cookie
|
||||
// 4. The coder_session_token query parameter
|
||||
// 5. The custom auth header
|
||||
func apiTokenFromRequest(r *http.Request) string {
|
||||
cookie, err := r.Cookie(codersdk.SessionTokenKey)
|
||||
if err == nil && cookie.Value != "" {
|
||||
@ -373,5 +382,33 @@ func apiTokenFromRequest(r *http.Request) string {
|
||||
return headerValue
|
||||
}
|
||||
|
||||
cookie, err = r.Cookie(DevURLSessionTokenCookie)
|
||||
if err == nil && cookie.Value != "" {
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SplitAPIToken verifies the format of an API key and returns the split ID and
|
||||
// secret.
|
||||
//
|
||||
// APIKeys are formatted: ${ID}-${SECRET}
|
||||
func SplitAPIToken(token string) (id string, secret string, err error) {
|
||||
parts := strings.Split(token, "-")
|
||||
if len(parts) != 2 {
|
||||
return "", "", xerrors.Errorf("incorrect amount of API key parts, expected 2 got %d", len(parts))
|
||||
}
|
||||
|
||||
// Ensure key lengths are valid.
|
||||
keyID := parts[0]
|
||||
keySecret := parts[1]
|
||||
if len(keyID) != 10 {
|
||||
return "", "", xerrors.Errorf("invalid API key ID length, expected 10 got %d", len(keyID))
|
||||
}
|
||||
if len(keySecret) != 22 {
|
||||
return "", "", xerrors.Errorf("invalid API key secret length, expected 22 got %d", len(keySecret))
|
||||
}
|
||||
|
||||
return keyID, keySecret, nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -46,7 +47,10 @@ func TestAPIKey(t *testing.T) {
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -59,7 +63,10 @@ func TestAPIKey(t *testing.T) {
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
httpmw.ExtractAPIKey(db, nil, true)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: true,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
location, err := res.Location()
|
||||
@ -77,7 +84,10 @@ func TestAPIKey(t *testing.T) {
|
||||
)
|
||||
r.Header.Set(codersdk.SessionCustomHeader, "test-wow-hello")
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -92,7 +102,10 @@ func TestAPIKey(t *testing.T) {
|
||||
)
|
||||
r.Header.Set(codersdk.SessionCustomHeader, "test-wow")
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -107,7 +120,10 @@ func TestAPIKey(t *testing.T) {
|
||||
)
|
||||
r.Header.Set(codersdk.SessionCustomHeader, "testtestid-wow")
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -123,7 +139,10 @@ func TestAPIKey(t *testing.T) {
|
||||
)
|
||||
r.Header.Set(codersdk.SessionCustomHeader, fmt.Sprintf("%s-%s", id, secret))
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -149,7 +168,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -175,7 +197,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
@ -202,7 +227,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
// Checks that it exists on the context!
|
||||
_ = httpmw.APIKey(r)
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
@ -244,7 +272,10 @@ func TestAPIKey(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
// Checks that it exists on the context!
|
||||
apiKey := httpmw.APIKey(r)
|
||||
assert.Equal(t, database.APIKeyScopeApplicationConnect, apiKey.Scope)
|
||||
@ -282,7 +313,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
// Checks that it exists on the context!
|
||||
_ = httpmw.APIKey(r)
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
@ -316,7 +350,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
@ -350,7 +387,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
@ -391,7 +431,10 @@ func TestAPIKey(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
@ -436,13 +479,17 @@ func TestAPIKey(t *testing.T) {
|
||||
RefreshToken: "moo",
|
||||
Expiry: database.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{
|
||||
Github: &oauth2Config{
|
||||
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
|
||||
return token, nil
|
||||
}),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{
|
||||
Github: &oauth2Config{
|
||||
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
|
||||
return token, nil
|
||||
}),
|
||||
},
|
||||
},
|
||||
}, false)(successHandler).ServeHTTP(rw, r)
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
@ -477,7 +524,10 @@ func TestAPIKey(t *testing.T) {
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r)
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
@ -487,6 +537,58 @@ func TestAPIKey(t *testing.T) {
|
||||
|
||||
require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP)
|
||||
})
|
||||
|
||||
t.Run("RedirectToLogin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
)
|
||||
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: true,
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
|
||||
u, err := res.Location()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/login", u.Path)
|
||||
})
|
||||
|
||||
t.Run("Optional", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rw = httptest.NewRecorder()
|
||||
|
||||
count int64
|
||||
handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt64(&count, 1)
|
||||
|
||||
apiKey, ok := httpmw.APIKeyOptional(r)
|
||||
assert.False(t, ok)
|
||||
assert.Zero(t, apiKey)
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
)
|
||||
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
Optional: true,
|
||||
})(handler).ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
require.EqualValues(t, 1, atomic.LoadInt64(&count))
|
||||
})
|
||||
}
|
||||
|
||||
func createUser(ctx context.Context, t *testing.T, db database.Store) database.User {
|
||||
|
@ -84,7 +84,11 @@ func TestExtractUserRoles(t *testing.T) {
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{},
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
)
|
||||
rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) {
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
|
@ -67,7 +67,10 @@ func TestOrganizationParam(t *testing.T) {
|
||||
rtr = chi.NewRouter()
|
||||
)
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
@ -87,7 +90,10 @@ func TestOrganizationParam(t *testing.T) {
|
||||
)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.NewString())
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
@ -107,7 +113,10 @@ func TestOrganizationParam(t *testing.T) {
|
||||
)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", "not-a-uuid")
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
@ -135,7 +144,10 @@ func TestOrganizationParam(t *testing.T) {
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
|
||||
chi.RouteContext(r.Context()).URLParams.Add("user", u.ID.String())
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractUserParam(db),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
httpmw.ExtractOrganizationMemberParam(db),
|
||||
@ -172,7 +184,10 @@ func TestOrganizationParam(t *testing.T) {
|
||||
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID.String())
|
||||
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
httpmw.ExtractUserParam(db),
|
||||
httpmw.ExtractOrganizationMemberParam(db),
|
||||
|
@ -132,7 +132,10 @@ func TestTemplateParam(t *testing.T) {
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractTemplateParam(db),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
|
@ -124,7 +124,10 @@ func TestTemplateVersionParam(t *testing.T) {
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractTemplateVersionParam(db),
|
||||
httpmw.ExtractOrganizationParam(db),
|
||||
)
|
||||
|
@ -56,7 +56,10 @@ func TestUserParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
r = returnedRequest
|
||||
})).ServeHTTP(rw, r)
|
||||
|
||||
@ -72,7 +75,10 @@ func TestUserParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
r = returnedRequest
|
||||
})).ServeHTTP(rw, r)
|
||||
|
||||
@ -91,7 +97,10 @@ func TestUserParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, rw, r := setup(t)
|
||||
|
||||
httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
})(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
|
||||
r = returnedRequest
|
||||
})).ServeHTTP(rw, r)
|
||||
|
||||
|
@ -132,7 +132,10 @@ func TestWorkspaceAgentParam(t *testing.T) {
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractWorkspaceAgentParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -107,7 +107,10 @@ func TestWorkspaceBuildParam(t *testing.T) {
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractWorkspaceBuildParam(db),
|
||||
httpmw.ExtractWorkspaceParam(db),
|
||||
)
|
||||
|
@ -100,7 +100,10 @@ func TestWorkspaceParam(t *testing.T) {
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, false),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: false,
|
||||
}),
|
||||
httpmw.ExtractWorkspaceParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -298,7 +301,10 @@ func TestWorkspaceAgentByNameParam(t *testing.T) {
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil, true),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
RedirectToLogin: true,
|
||||
}),
|
||||
httpmw.ExtractUserParam(db),
|
||||
httpmw.ExtractWorkspaceAndAgentParam(db),
|
||||
)
|
||||
|
Reference in New Issue
Block a user