Files
coder/coderd/httpmw/rfc6750_extended_test.go
Thomas Kosiewski 09c50559f3 feat: implement RFC 6750 Bearer token authentication (#18644)
# Add RFC 6750 Bearer Token Authentication Support

This PR implements RFC 6750 Bearer Token authentication as an additional authentication method for Coder's API. This allows clients to authenticate using standard OAuth 2.0 Bearer tokens in two ways:

1. Using the `Authorization: Bearer <token>` header
2. Using the `access_token` query parameter

Key changes:

- Added support for extracting tokens from both Bearer headers and access_token query parameters
- Implemented proper WWW-Authenticate headers for 401/403 responses with appropriate error descriptions
- Added comprehensive test coverage for the new authentication methods
- Updated the OAuth2 protected resource metadata endpoint to advertise Bearer token support
- Enhanced the OAuth2 testing script to verify Bearer token functionality

These authentication methods are added as fallback options, maintaining backward compatibility with Coder's existing authentication mechanisms. The existing authentication methods (cookies, session token header, etc.) still take precedence.

This implementation follows the OAuth 2.0 Bearer Token specification (RFC 6750) and improves interoperability with standard OAuth 2.0 clients.
2025-07-02 19:14:54 +02:00

444 lines
13 KiB
Go

package httpmw_test
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestOAuth2BearerTokenSecurityBoundaries verifies that an RFC 6750 Bearer
// token authenticates exactly the user who owns it, and that forged or
// garbage tokens are rejected with a 401 and a Bearer WWW-Authenticate
// challenge.
func TestOAuth2BearerTokenSecurityBoundaries(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)

	// Two distinct users, each owning one unexpired API key.
	user1 := dbgen.User(t, db, database.User{})
	user2 := dbgen.User(t, db, database.User{})

	key1, token1 := dbgen.APIKey(t, db, database.APIKey{
		UserID:    user1.ID,
		ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
	})
	_, token2 := dbgen.APIKey(t, db, database.APIKey{
		UserID:    user2.ID,
		ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
	})

	t.Run("TokenIsolation", func(t *testing.T) {
		t.Parallel()

		middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
			DB: db,
		})
		// Echo the authenticated user's ID so the test can see whose key
		// the middleware resolved.
		handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			apiKey := httpmw.APIKey(r)
			rw.Header().Set("X-User-ID", apiKey.UserID.String())
			rw.WriteHeader(http.StatusOK)
		}))

		// user1's token must resolve to user1.
		req1 := httptest.NewRequest("GET", "/test", nil)
		req1.Header.Set("Authorization", "Bearer "+token1)
		rec1 := httptest.NewRecorder()
		handler.ServeHTTP(rec1, req1)
		require.Equal(t, http.StatusOK, rec1.Code)
		require.Equal(t, user1.ID.String(), rec1.Header().Get("X-User-ID"))

		// user2's token must resolve to user2.
		req2 := httptest.NewRequest("GET", "/test", nil)
		req2.Header.Set("Authorization", "Bearer "+token2)
		rec2 := httptest.NewRecorder()
		handler.ServeHTTP(rec2, req2)
		require.Equal(t, http.StatusOK, rec2.Code)
		require.Equal(t, user2.ID.String(), rec2.Header().Get("X-User-ID"))

		// The two tokens must never map to the same identity.
		require.NotEqual(t, rec1.Header().Get("X-User-ID"), rec2.Header().Get("X-User-ID"))
	})

	t.Run("CrossTokenAttempts", func(t *testing.T) {
		t.Parallel()

		middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
			DB: db,
		})
		handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			rw.WriteHeader(http.StatusOK)
		}))

		// Forge a token that reuses user1's real key ID but carries a wrong
		// secret of the same length. Keeping the <id>-<secret> shape intact
		// matters: appending something like "-invalid-secret" to the ID
		// introduces extra "-" separators, and such a token may be rejected
		// by format parsing before the stored secret is ever compared.
		id, secret, ok := strings.Cut(token1, "-")
		require.True(t, ok, "expected API token in <id>-<secret> form")
		require.Equal(t, key1.ID, id)
		require.NotEmpty(t, secret)
		replacement := "a"
		if strings.HasSuffix(secret, replacement) {
			replacement = "b"
		}
		invalidToken := id + "-" + secret[:len(secret)-1] + replacement
		require.NotEqual(t, token1, invalidToken)

		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "Bearer "+invalidToken)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusUnauthorized, rec.Code)
		require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer")
		require.Contains(t, rec.Header().Get("WWW-Authenticate"), "invalid_token")
	})

	t.Run("TimingAttackResistance", func(t *testing.T) {
		t.Parallel()

		middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
			DB: db,
		})
		handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			rw.WriteHeader(http.StatusOK)
		}))

		// Wall-clock timing is too noisy to assert on in a unit test, so this
		// only checks that invalid tokens of very different lengths are all
		// rejected the same way: a 401 carrying a Bearer challenge.
		invalidTokens := []string{
			"invalid-token-1",
			"invalid-token-2-longer",
			"a",
			strings.Repeat("x", 100),
		}
		for _, token := range invalidTokens {
			req := httptest.NewRequest("GET", "/test", nil)
			req.Header.Set("Authorization", "Bearer "+token)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			require.Equal(t, http.StatusUnauthorized, rec.Code)
			require.Contains(t, rec.Header().Get("WWW-Authenticate"), "Bearer")
		}
	})
}
// TestOAuth2BearerTokenMalformedHeaders checks that malformed or bogus
// Authorization headers are rejected with a 401 and an RFC 6750 Bearer
// challenge (realm="coder") rather than being accepted or crashing the
// middleware.
func TestOAuth2BearerTokenMalformedHeaders(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)

	middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
		DB: db,
	})
	handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.WriteHeader(http.StatusOK)
	}))

	tests := []struct {
		name           string
		authHeader     string
		expectedStatus int
		shouldHaveWWW  bool
	}{
		{
			// No "Bearer" scheme at all.
			name:           "MissingBearer",
			authHeader:     "invalid-token",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			// Lowercase scheme. Whether or not the scheme is matched
			// case-insensitively, the token itself is invalid, so the
			// request must still be rejected.
			name:           "CaseSensitive",
			authHeader:     "bearer token",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			// Several spaces between the scheme and the token.
			name:           "ExtraSpaces",
			authHeader:     "Bearer    token-with-extra-spaces",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			name:           "EmptyToken",
			authHeader:     "Bearer ",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			name:           "OnlyBearer",
			authHeader:     "Bearer",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			name:           "MultipleBearer",
			authHeader:     "Bearer token1 Bearer token2",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
		{
			name:           "InvalidBase64",
			authHeader:     "Bearer !!!invalid-base64!!!",
			expectedStatus: http.StatusUnauthorized,
			shouldHaveWWW:  true,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest("GET", "/test", nil)
			req.Header.Set("Authorization", test.authHeader)
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, req)

			require.Equal(t, test.expectedStatus, rec.Code)
			if test.shouldHaveWWW {
				wwwAuth := rec.Header().Get("WWW-Authenticate")
				require.Contains(t, wwwAuth, "Bearer")
				require.Contains(t, wwwAuth, "realm=\"coder\"")
			}
		})
	}
}
// TestOAuth2BearerTokenPrecedence checks the order in which the API key
// middleware picks a token when several transports are present. Coder's
// established methods (session cookie, coder_session_token query parameter)
// must win over the RFC 6750 fallbacks (Authorization: Bearer header and
// access_token query parameter). The fallbacks must still work when used
// alone.
func TestOAuth2BearerTokenPrecedence(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	user := dbgen.User(t, db, database.User{})

	key, validToken := dbgen.APIKey(t, db, database.APIKey{
		UserID:    user.ID,
		ExpiresAt: dbtime.Now().Add(testutil.WaitLong),
	})

	middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
		DB: db,
	})
	// Echo the resolved key ID so each subtest can confirm which key won.
	handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		apiKey := httpmw.APIKey(r)
		rw.Header().Set("X-Key-ID", apiKey.ID)
		rw.WriteHeader(http.StatusOK)
	}))

	// requestURL returns "/test" with params encoded as its query string.
	// Building the URL this way avoids the error path of url.Parse entirely.
	requestURL := func(params url.Values) string {
		return "/test?" + params.Encode()
	}

	t.Run("CookieTakesPrecedenceOverBearer", func(t *testing.T) {
		t.Parallel()

		// A valid cookie and an invalid Bearer header: the cookie must win.
		req := httptest.NewRequest("GET", "/test", nil)
		req.AddCookie(&http.Cookie{
			Name:  codersdk.SessionTokenCookie,
			Value: validToken,
		})
		req.Header.Set("Authorization", "Bearer invalid-token")
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusOK, rec.Code)
		require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
	})

	t.Run("QueryParameterTakesPrecedenceOverBearer", func(t *testing.T) {
		t.Parallel()

		// A valid session-token query parameter and an invalid Bearer
		// header: the query parameter must win.
		q := url.Values{}
		q.Set(codersdk.SessionTokenCookie, validToken)
		req := httptest.NewRequest("GET", requestURL(q), nil)
		req.Header.Set("Authorization", "Bearer invalid-token")
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusOK, rec.Code)
		require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
	})

	t.Run("BearerHeaderFallback", func(t *testing.T) {
		t.Parallel()

		// Only the Bearer header is present, so it must be used.
		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "Bearer "+validToken)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusOK, rec.Code)
		require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
	})

	t.Run("AccessTokenQueryParameterFallback", func(t *testing.T) {
		t.Parallel()

		// Only the RFC 6750 access_token query parameter is present.
		q := url.Values{}
		q.Set("access_token", validToken)
		req := httptest.NewRequest("GET", requestURL(q), nil)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusOK, rec.Code)
		require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
	})

	t.Run("MultipleAuthMethodsShouldNotConflict", func(t *testing.T) {
		t.Parallel()

		// RFC 6750 section 2 says clients MUST NOT send the token via more
		// than one method. If a client does anyway, the established Coder
		// methods must take precedence. The RFC 6750 fallbacks carry an
		// invalid token here, so a success proves they were not the ones
		// consulted. (Sending the same valid token everywhere could not
		// distinguish which method was used.)
		q := url.Values{}
		q.Set("access_token", "invalid-token")
		q.Set(codersdk.SessionTokenCookie, validToken)
		req := httptest.NewRequest("GET", requestURL(q), nil)
		req.Header.Set("Authorization", "Bearer invalid-token")
		req.AddCookie(&http.Cookie{
			Name:  codersdk.SessionTokenCookie,
			Value: validToken,
		})
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusOK, rec.Code)
		require.Equal(t, key.ID, rec.Header().Get("X-Key-ID"))
	})
}
// TestOAuth2WWWAuthenticateCompliance checks that authentication failures
// carry an RFC 6750 section 3 WWW-Authenticate challenge: the Bearer scheme,
// realm="coder", an error code, and an error_description.
func TestOAuth2WWWAuthenticateCompliance(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	user := dbgen.User(t, db, database.User{})

	// One middleware/handler pair is shared by every subtest; the handler is
	// only reached when authentication succeeds.
	middleware := httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
		DB: db,
	})
	handler := middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.WriteHeader(http.StatusOK)
	}))

	t.Run("UnauthorizedResponse", func(t *testing.T) {
		t.Parallel()

		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "Bearer invalid-token")
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusUnauthorized, rec.Code)
		wwwAuth := rec.Header().Get("WWW-Authenticate")
		require.NotEmpty(t, wwwAuth)
		// RFC 6750 section 3: the challenge starts with the "Bearer" scheme,
		// followed by its auth-params.
		require.True(t, strings.HasPrefix(wwwAuth, "Bearer "), "challenge must use the Bearer scheme: %q", wwwAuth)
		require.Contains(t, wwwAuth, "realm=\"coder\"")
		require.Contains(t, wwwAuth, "error=\"invalid_token\"")
		require.Contains(t, wwwAuth, "error_description=")
	})

	t.Run("ExpiredTokenResponse", func(t *testing.T) {
		t.Parallel()

		// An API key that expired an hour ago.
		_, expiredToken := dbgen.APIKey(t, db, database.APIKey{
			UserID:    user.ID,
			ExpiresAt: dbtime.Now().Add(-time.Hour),
		})

		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("Authorization", "Bearer "+expiredToken)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusUnauthorized, rec.Code)
		wwwAuth := rec.Header().Get("WWW-Authenticate")
		require.Contains(t, wwwAuth, "Bearer")
		require.Contains(t, wwwAuth, "realm=\"coder\"")
		require.Contains(t, wwwAuth, "error=\"invalid_token\"")
		require.Contains(t, wwwAuth, "error_description=\"The access token has expired\"")
	})

	t.Run("InsufficientScopeResponse", func(t *testing.T) {
		t.Parallel()

		// A genuine insufficient_scope (403) response needs an authenticated
		// request that authorization later rejects. This middleware alone
		// cannot produce that, so this subtest only confirms that an invalid
		// token on a privileged-looking path is still reported as a 401
		// invalid_token. It does not get mislabeled as a 403.
		req := httptest.NewRequest("GET", "/admin", nil)
		req.Header.Set("Authorization", "Bearer invalid-token-that-triggers-403")
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, req)

		require.Equal(t, http.StatusUnauthorized, rec.Code)
		wwwAuth := rec.Header().Get("WWW-Authenticate")
		require.Contains(t, wwwAuth, "Bearer")
		require.Contains(t, wwwAuth, "realm=\"coder\"")
		require.Contains(t, wwwAuth, "error=\"invalid_token\"")
		require.Contains(t, wwwAuth, "error_description=")
	})
}