mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: implement RFC 6750 Bearer token authentication (#18644)
# Add RFC 6750 Bearer Token Authentication Support This PR implements RFC 6750 Bearer Token authentication as an additional authentication method for Coder's API. This allows clients to authenticate using standard OAuth 2.0 Bearer tokens in two ways: 1. Using the `Authorization: Bearer <token>` header 2. Using the `access_token` query parameter Key changes: - Added support for extracting tokens from both Bearer headers and access_token query parameters - Implemented proper WWW-Authenticate headers for 401/403 responses with appropriate error descriptions - Added comprehensive test coverage for the new authentication methods - Updated the OAuth2 protected resource metadata endpoint to advertise Bearer token support - Enhanced the OAuth2 testing script to verify Bearer token functionality These authentication methods are added as fallback options, maintaining backward compatibility with Coder's existing authentication mechanisms. The existing authentication methods (cookies, session token header, etc.) still take precedence. This implementation follows the OAuth 2.0 Bearer Token specification (RFC 6750) and improves interoperability with standard OAuth 2.0 clients.
This commit is contained in:
@ -214,6 +214,31 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// Add WWW-Authenticate header for 401/403 responses (RFC 6750)
|
||||
if code == http.StatusUnauthorized || code == http.StatusForbidden {
|
||||
var wwwAuth string
|
||||
|
||||
switch code {
|
||||
case http.StatusUnauthorized:
|
||||
// Map 401 to invalid_token with specific error descriptions
|
||||
switch {
|
||||
case strings.Contains(response.Message, "expired") || strings.Contains(response.Detail, "expired"):
|
||||
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token has expired"`
|
||||
case strings.Contains(response.Message, "audience") || strings.Contains(response.Message, "mismatch"):
|
||||
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token audience does not match this resource"`
|
||||
default:
|
||||
wwwAuth = `Bearer realm="coder", error="invalid_token", error_description="The access token is invalid"`
|
||||
}
|
||||
case http.StatusForbidden:
|
||||
// Map 403 to insufficient_scope per RFC 6750
|
||||
wwwAuth = `Bearer realm="coder", error="insufficient_scope", error_description="The request requires higher privileges than provided by the access token"`
|
||||
default:
|
||||
wwwAuth = `Bearer realm="coder"`
|
||||
}
|
||||
|
||||
rw.Header().Set("WWW-Authenticate", wwwAuth)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, code, response)
|
||||
return nil, nil, false
|
||||
}
|
||||
@ -653,9 +678,14 @@ func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, s
|
||||
// 1: The cookie
|
||||
// 2. The coder_session_token query parameter
|
||||
// 3. The custom auth header
|
||||
// 4. RFC 6750 Authorization: Bearer header
|
||||
// 5. RFC 6750 access_token query parameter
|
||||
//
|
||||
// API tokens for apps are read from workspaceapps/cookies.go.
|
||||
func APITokenFromRequest(r *http.Request) string {
|
||||
// Prioritize existing Coder custom authentication methods first
|
||||
// to maintain backward compatibility and existing behavior
|
||||
|
||||
cookie, err := r.Cookie(codersdk.SessionTokenCookie)
|
||||
if err == nil && cookie.Value != "" {
|
||||
return cookie.Value
|
||||
@ -671,7 +701,18 @@ func APITokenFromRequest(r *http.Request) string {
|
||||
return headerValue
|
||||
}
|
||||
|
||||
// TODO(ThomasK33): Implement RFC 6750
|
||||
// RFC 6750 Bearer Token support (added as fallback methods)
|
||||
// Check Authorization: Bearer <token> header (case-insensitive per RFC 6750)
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return authHeader[7:] // Skip "Bearer " (7 characters)
|
||||
}
|
||||
|
||||
// Check access_token query parameter
|
||||
accessToken := r.URL.Query().Get("access_token")
|
||||
if accessToken != "" {
|
||||
return accessToken
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
@ -102,6 +102,12 @@ func CSRF(cookieCfg codersdk.HTTPCookieConfig) func(next http.Handler) http.Hand
|
||||
return true
|
||||
}
|
||||
|
||||
// RFC 6750 Bearer Token authentication is exempt from CSRF
|
||||
// as it uses custom headers that cannot be set by malicious sites
|
||||
if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
|
||||
return true
|
||||
}
|
||||
|
||||
// If the X-CSRF-TOKEN header is set, we can exempt the func if it's valid.
|
||||
// This is the CSRF check.
|
||||
sent := r.Header.Get("X-CSRF-TOKEN")
|
||||
|
443
coderd/httpmw/rfc6750_extended_test.go
Normal file
443
coderd/httpmw/rfc6750_extended_test.go
Normal file
@ -0,0 +1,443 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestOAuth2BearerTokenSecurityBoundaries tests RFC 6750 security boundaries
|
||||
func TestOAuth2BearerTokenSecurityBoundaries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// Create two different users with different API keys
|
||||
user1 := dbgen.User(t, db, database.User{})
|
||||
user2 := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create API keys for both users
|
||||
key1, token1 := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user1.ID,
|
||||
ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
|
||||
})
|
||||
|
||||
_, token2 := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user2.ID,
|
||||
ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
|
||||
})
|
||||
|
||||
t.Run("TokenIsolation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create middleware
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
// Handler that returns the authenticated user ID
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := httpmw.APIKey(r)
|
||||
rw.Header().Set("X-User-ID", apiKey.UserID.String())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test that user1's token only accesses user1's data
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.Header.Set("Authorization", "Bearer "+token1)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec1, req1)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec1.Code)
|
||||
require.Equal(t, user1.ID.String(), rec1.Header().Get("X-User-ID"))
|
||||
|
||||
// Test that user2's token only accesses user2's data
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("Authorization", "Bearer "+token2)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec2, req2)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec2.Code)
|
||||
require.Equal(t, user2.ID.String(), rec2.Header().Get("X-User-ID"))
|
||||
|
||||
// Verify users can't access each other's data
|
||||
require.NotEqual(t, rec1.Header().Get("X-User-ID"), rec2.Header().Get("X-User-ID"))
|
||||
})
|
||||
|
||||
t.Run("CrossTokenAttempts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Try to use invalid token (should fail)
|
||||
invalidToken := key1.ID + "-invalid-secret"
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+invalidToken)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer")
|
||||
require.Contains(t, rec.Header().Get("WWW-Authenticate"), "invalid_token")
|
||||
})
|
||||
|
||||
t.Run("TimingAttackResistance", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test multiple invalid tokens to ensure consistent timing
|
||||
invalidTokens := []string{
|
||||
"invalid-token-1",
|
||||
"invalid-token-2-longer",
|
||||
"a",
|
||||
strings.Repeat("x", 100),
|
||||
}
|
||||
|
||||
times := make([]time.Duration, len(invalidTokens))
|
||||
|
||||
for i, token := range invalidTokens {
|
||||
start := time.Now()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
times[i] = time.Since(start)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
// While we can't guarantee perfect timing consistency in tests,
|
||||
// we can at least verify that the responses are all unauthorized
|
||||
// and contain proper WWW-Authenticate headers
|
||||
for _, token := range invalidTokens {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestOAuth2BearerTokenMalformedHeaders tests handling of malformed Bearer headers per RFC 6750
|
||||
func TestOAuth2BearerTokenMalformedHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedStatus int
|
||||
shouldHaveWWW bool
|
||||
}{
|
||||
{
|
||||
name: "MissingBearer",
|
||||
authHeader: "invalid-token",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "CaseSensitive",
|
||||
authHeader: "bearer token", // lowercase should still work
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "ExtraSpaces",
|
||||
authHeader: "Bearer token-with-extra-spaces",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "EmptyToken",
|
||||
authHeader: "Bearer ",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "OnlyBearer",
|
||||
authHeader: "Bearer",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "MultipleBearer",
|
||||
authHeader: "Bearer token1 Bearer token2",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidBase64",
|
||||
authHeader: "Bearer !!!invalid-base64!!!",
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
shouldHaveWWW: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", test.authHeader)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, test.expectedStatus, rec.Code)
|
||||
|
||||
if test.shouldHaveWWW {
|
||||
wwwAuth := rec.Header().Get("WWW-Authenticate")
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, "realm=\"coder\"")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuth2BearerTokenPrecedence tests token extraction precedence per RFC 6750
|
||||
func TestOAuth2BearerTokenPrecedence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create a valid API key
|
||||
key, validToken := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
|
||||
})
|
||||
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := httpmw.APIKey(r)
|
||||
rw.Header().Set("X-Key-ID", apiKey.ID)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("CookieTakesPrecedenceOverBearer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// Set both cookie and Bearer header - cookie should take precedence
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: validToken,
|
||||
})
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
|
||||
})
|
||||
|
||||
t.Run("QueryParameterTakesPrecedenceOverBearer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Set both query parameter and Bearer header - query should take precedence
|
||||
u, _ := url.Parse("/test")
|
||||
q := u.Query()
|
||||
q.Set(codersdk.SessionTokenCookie, validToken)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req := httptest.NewRequest("GET", u.String(), nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
|
||||
})
|
||||
|
||||
t.Run("BearerHeaderFallback", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Only set Bearer header - should be used as fallback
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+validToken)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
|
||||
})
|
||||
|
||||
t.Run("AccessTokenQueryParameterFallback", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Only set access_token query parameter - should be used as fallback
|
||||
u, _ := url.Parse("/test")
|
||||
q := u.Query()
|
||||
q.Set("access_token", validToken)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req := httptest.NewRequest("GET", u.String(), nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
|
||||
})
|
||||
|
||||
t.Run("MultipleAuthMethodsShouldNotConflict", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// RFC 6750 says clients shouldn't send tokens in multiple ways,
|
||||
// but if they do, we should handle it gracefully by using precedence
|
||||
u, _ := url.Parse("/test")
|
||||
q := u.Query()
|
||||
q.Set("access_token", validToken)
|
||||
q.Set(codersdk.SessionTokenCookie, validToken)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req := httptest.NewRequest("GET", u.String(), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+validToken)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: validToken,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
// Should succeed using the highest precedence method (cookie)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestOAuth2WWWAuthenticateCompliance tests WWW-Authenticate header compliance with RFC 6750
|
||||
func TestOAuth2WWWAuthenticateCompliance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
t.Run("UnauthorizedResponse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
|
||||
wwwAuth := rec.Header().Get("WWW-Authenticate")
|
||||
require.NotEmpty(t, wwwAuth)
|
||||
|
||||
// RFC 6750 requires specific format: Bearer realm="realm"
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, "realm=\"coder\"")
|
||||
require.Contains(t, wwwAuth, "error=\"invalid_token\"")
|
||||
require.Contains(t, wwwAuth, "error_description=")
|
||||
})
|
||||
|
||||
t.Run("ExpiredTokenResponse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an expired API key
|
||||
_, expiredToken := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
ExpiresAt: dbtime.Now().Add(-time.Hour), // Expired 1 hour ago
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+expiredToken)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
|
||||
wwwAuth := rec.Header().Get("WWW-Authenticate")
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, "realm=\"coder\"")
|
||||
require.Contains(t, wwwAuth, "error=\"invalid_token\"")
|
||||
require.Contains(t, wwwAuth, "error_description=\"The access token has expired\"")
|
||||
})
|
||||
|
||||
t.Run("InsufficientScopeResponse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// For this test, we'll test with an invalid token to trigger the middleware's
|
||||
// error handling which does set WWW-Authenticate headers for 403 responses
|
||||
// In practice, insufficient scope errors would be handled by RBAC middleware
|
||||
// that comes after authentication, but we can simulate a 403 from the auth middleware
|
||||
|
||||
req := httptest.NewRequest("GET", "/admin", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token-that-triggers-403")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Use a middleware configuration that might trigger a 403 instead of 401
|
||||
// for certain types of authentication failures
|
||||
middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
})
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
// This shouldn't be reached due to auth failure
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
// This will be a 401 (unauthorized) rather than 403 (forbidden) for invalid tokens
|
||||
// which is correct - 403 would come from RBAC after successful authentication
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
|
||||
wwwAuth := rec.Header().Get("WWW-Authenticate")
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, "realm=\"coder\"")
|
||||
require.Contains(t, wwwAuth, "error=\"invalid_token\"")
|
||||
require.Contains(t, wwwAuth, "error_description=")
|
||||
})
|
||||
}
|
241
coderd/httpmw/rfc6750_test.go
Normal file
241
coderd/httpmw/rfc6750_test.go
Normal file
@ -0,0 +1,241 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestRFC6750BearerTokenAuthentication tests that RFC 6750 bearer tokens work correctly
|
||||
// for authentication, including both Authorization header and access_token query parameter methods.
|
||||
func TestRFC6750BearerTokenAuthentication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
// Create a test user and API key
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create an OAuth2 provider app token (which should work with bearer token authentication)
|
||||
key, token := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
|
||||
})
|
||||
|
||||
cfg := httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
}
|
||||
|
||||
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := httpmw.APIKey(r)
|
||||
require.Equal(t, key.ID, apiKey.ID)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("AuthorizationBearerHeader", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
})
|
||||
|
||||
t.Run("AccessTokenQueryParameter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test?access_token="+url.QueryEscape(token), nil)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
})
|
||||
|
||||
t.Run("BearerTokenPriorityAfterCustomMethods", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a different token for custom header
|
||||
customKey, customToken := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
|
||||
})
|
||||
|
||||
// Create handler that checks which token was used
|
||||
priorityHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKey := httpmw.APIKey(r)
|
||||
// Should use the custom header token, not the bearer token
|
||||
require.Equal(t, customKey.ID, apiKey.ID)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// Set both custom header and bearer header - custom should win
|
||||
req.Header.Set(codersdk.SessionTokenHeader, customToken)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(priorityHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
})
|
||||
|
||||
t.Run("InvalidBearerToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Code)
|
||||
|
||||
// Check that WWW-Authenticate header is present
|
||||
wwwAuth := rw.Header().Get("WWW-Authenticate")
|
||||
require.NotEmpty(t, wwwAuth)
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, `realm="coder"`)
|
||||
require.Contains(t, wwwAuth, "invalid_token")
|
||||
})
|
||||
|
||||
t.Run("ExpiredBearerToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create an expired token
|
||||
_, expiredToken := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
ExpiresAt: dbtime.Now().Add(-testutil.WaitShort), // Expired
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+expiredToken)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Code)
|
||||
|
||||
// Check that WWW-Authenticate header contains expired error
|
||||
wwwAuth := rw.Header().Get("WWW-Authenticate")
|
||||
require.NotEmpty(t, wwwAuth)
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, `realm="coder"`)
|
||||
require.Contains(t, wwwAuth, "expired")
|
||||
})
|
||||
|
||||
t.Run("MissingBearerToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// No authentication provided
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
httpmw.ExtractAPIKeyMW(cfg)(testHandler).ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rw.Code)
|
||||
|
||||
// Check that WWW-Authenticate header is present
|
||||
wwwAuth := rw.Header().Get("WWW-Authenticate")
|
||||
require.NotEmpty(t, wwwAuth)
|
||||
require.Contains(t, wwwAuth, "Bearer")
|
||||
require.Contains(t, wwwAuth, `realm="coder"`)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAPITokenFromRequest tests the RFC 6750 bearer token extraction directly
|
||||
func TestAPITokenFromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
token := "test-token-value"
|
||||
customToken := "custom-token"
|
||||
cookieToken := "cookie-token"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func(*http.Request)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "AuthorizationBearerHeader",
|
||||
setupReq: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expected: token,
|
||||
},
|
||||
{
|
||||
name: "AccessTokenQueryParameter",
|
||||
setupReq: func(req *http.Request) {
|
||||
q := req.URL.Query()
|
||||
q.Set("access_token", token)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
},
|
||||
expected: token,
|
||||
},
|
||||
{
|
||||
name: "CustomMethodsPriorityOverBearer",
|
||||
setupReq: func(req *http.Request) {
|
||||
req.Header.Set(codersdk.SessionTokenHeader, customToken)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expected: customToken,
|
||||
},
|
||||
{
|
||||
name: "CookiePriorityOverBearer",
|
||||
setupReq: func(req *http.Request) {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: cookieToken,
|
||||
})
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
},
|
||||
expected: cookieToken,
|
||||
},
|
||||
{
|
||||
name: "NoTokenReturnsEmpty",
|
||||
setupReq: func(req *http.Request) {
|
||||
// No authentication provided
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "InvalidAuthorizationHeaderIgnored",
|
||||
setupReq: func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") // Basic auth, not Bearer
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
tt.setupReq(req)
|
||||
|
||||
extractedToken := httpmw.APITokenFromRequest(req)
|
||||
require.Equal(t, tt.expected, extractedToken)
|
||||
})
|
||||
}
|
||||
}
|
@ -431,9 +431,8 @@ func (api *API) oauth2ProtectedResourceMetadata(rw http.ResponseWriter, r *http.
|
||||
AuthorizationServers: []string{api.AccessURL.String()},
|
||||
// TODO: Implement scope system based on RBAC permissions
|
||||
ScopesSupported: []string{},
|
||||
// Note: Coder uses custom authentication methods, not RFC 6750 bearer tokens
|
||||
// TODO(ThomasK33): Implement RFC 6750
|
||||
// BearerMethodsSupported: []string{}, // Omitted - no standard bearer token support
|
||||
// RFC 6750 Bearer Token methods supported as fallback methods in api key middleware
|
||||
BearerMethodsSupported: []string{"header", "query"},
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, metadata)
|
||||
}
|
||||
|
@ -77,9 +77,9 @@ func TestOAuth2ProtectedResourceMetadata(t *testing.T) {
|
||||
require.NotEmpty(t, metadata.AuthorizationServers)
|
||||
require.Len(t, metadata.AuthorizationServers, 1)
|
||||
require.Equal(t, metadata.Resource, metadata.AuthorizationServers[0])
|
||||
// BearerMethodsSupported is omitted since Coder uses custom authentication methods
|
||||
// Standard RFC 6750 bearer tokens are not supported
|
||||
require.True(t, len(metadata.BearerMethodsSupported) == 0)
|
||||
// RFC 6750 bearer tokens are now supported as fallback methods
|
||||
require.Contains(t, metadata.BearerMethodsSupported, "header")
|
||||
require.Contains(t, metadata.BearerMethodsSupported, "query")
|
||||
// ScopesSupported can be empty until scope system is implemented
|
||||
// Empty slice is marshaled as empty array, but can be nil when unmarshaled
|
||||
require.True(t, len(metadata.ScopesSupported) == 0)
|
||||
|
@ -170,6 +170,53 @@ else
|
||||
echo -e "${RED}✗ Token refresh failed${NC}\n"
|
||||
fi
|
||||
|
||||
# Test 6: RFC 6750 Bearer Token Authentication
|
||||
echo -e "${YELLOW}Test 6: RFC 6750 Bearer Token Authentication${NC}"
|
||||
ACCESS_TOKEN=$(echo "$TOKEN_RESPONSE" | jq -r '.access_token')
|
||||
|
||||
# Test Authorization: Bearer header
|
||||
echo -e "${BLUE}Testing Authorization: Bearer header...${NC}"
|
||||
BEARER_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me" \
|
||||
-H "Authorization: Bearer $ACCESS_TOKEN")
|
||||
|
||||
HTTP_CODE="${BEARER_RESPONSE: -3}"
|
||||
if [ "$HTTP_CODE" = "200" ]; then
|
||||
echo -e "${GREEN}✓ Authorization: Bearer header working${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ Authorization: Bearer header failed (HTTP $HTTP_CODE)${NC}"
|
||||
fi
|
||||
|
||||
# Test access_token query parameter
|
||||
echo -e "${BLUE}Testing access_token query parameter...${NC}"
|
||||
QUERY_RESPONSE=$(curl -s -w "%{http_code}" "$BASE_URL/api/v2/users/me?access_token=$ACCESS_TOKEN")
|
||||
|
||||
HTTP_CODE="${QUERY_RESPONSE: -3}"
|
||||
if [ "$HTTP_CODE" = "200" ]; then
|
||||
echo -e "${GREEN}✓ access_token query parameter working${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ access_token query parameter failed (HTTP $HTTP_CODE)${NC}"
|
||||
fi
|
||||
|
||||
# Test WWW-Authenticate header on unauthorized request
|
||||
echo -e "${BLUE}Testing WWW-Authenticate header on 401...${NC}"
|
||||
UNAUTH_RESPONSE=$(curl -s -I "$BASE_URL/api/v2/users/me")
|
||||
if echo "$UNAUTH_RESPONSE" | grep -i "WWW-Authenticate.*Bearer" >/dev/null; then
|
||||
echo -e "${GREEN}✓ WWW-Authenticate header present${NC}"
|
||||
else
|
||||
echo -e "${RED}✗ WWW-Authenticate header missing${NC}"
|
||||
fi
|
||||
|
||||
# Test 7: Protected Resource Metadata
|
||||
echo -e "${YELLOW}Test 7: Protected Resource Metadata (RFC 9728)${NC}"
|
||||
PROTECTED_METADATA=$(curl -s "$BASE_URL/.well-known/oauth-protected-resource")
|
||||
echo "$PROTECTED_METADATA" | jq .
|
||||
|
||||
if echo "$PROTECTED_METADATA" | jq -e '.bearer_methods_supported[]' | grep -q "header"; then
|
||||
echo -e "${GREEN}✓ Protected Resource Metadata indicates bearer token support${NC}\n"
|
||||
else
|
||||
echo -e "${RED}✗ Protected Resource Metadata missing bearer token support${NC}\n"
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
echo -e "${YELLOW}Cleaning up...${NC}"
|
||||
curl -s -X DELETE "$BASE_URL/api/v2/oauth2-provider/apps/$CLIENT_ID" \
|
||||
|
Reference in New Issue
Block a user