mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat: oauth2 - add RFC 8707 resource indicators and audience validation (#18575)
This pull request implements RFC 8707, Resource Indicators for OAuth 2.0 (https://datatracker.ietf.org/doc/html/rfc8707), to enhance the security of our OAuth 2.0 provider. This change enables proper audience validation and binds access tokens to their intended resource, which is crucial for preventing token misuse in multi-tenant environments or deployments with multiple resource servers. ## Key Changes: * Resource Parameter Support: Adds support for the resource parameter in both the authorization (`/oauth2/authorize`) and token (`/oauth2/token`) endpoints, allowing clients to specify the intended resource server. * Audience Validation: Implements server-side validation to ensure that the resource parameter provided during the token exchange matches the one from the authorization request. * API Middleware Enforcement: Introduces a new validation step in the API authentication middleware (`coderd/httpmw/apikey.go`) to verify that the audience of the access token matches the resource server being accessed. * Database Schema Updates: * Adds a `resource_uri` column to the `oauth2_provider_app_codes` table to store the resource requested during authorization. * Adds an `audience` column to the `oauth2_provider_app_tokens` table to bind the issued token to a specific audience. * Enhanced PKCE: Includes a minor enhancement to the PKCE implementation to protect against timing attacks. * Comprehensive Testing: Adds extensive new tests to `coderd/oauth2_test.go` to cover various RFC 8707 scenarios, including valid flows, mismatched resources, and refresh token validation. ## How it Works: 1. An OAuth2 client specifies the target resource (e.g., https://coder.example.com) using the resource parameter in the authorization request. 2. The authorization server stores this resource URI with the authorization code. 3. During the token exchange, the server validates that the client provides the same resource parameter. 4. The server issues an access token with an audience claim set to the validated resource URI. 5. When the client uses the access token to call an API endpoint, the middleware verifies that the token's audience matches the URL of the Coder deployment, rejecting any tokens intended for a different resource. This ensures that a token issued for one Coder deployment cannot be used to access another, significantly strengthening our authentication security. --- Change-Id: I3924cb2139e837e3ac0b0bd40a5aeb59637ebc1b Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
@ -15,9 +15,11 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/net/idna"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
@ -110,6 +112,9 @@ type ExtractAPIKeyConfig struct {
|
||||
// This is originally implemented to send entitlement warning headers after
|
||||
// a user is authenticated to prevent additional CLI invocations.
|
||||
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)
|
||||
|
||||
// Logger is used for logging middleware operations.
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
|
||||
@ -240,6 +245,17 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
})
|
||||
}
|
||||
|
||||
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
|
||||
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
|
||||
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
|
||||
// Log the detailed error for debugging but don't expose it to the client
|
||||
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
|
||||
return optionalWrite(http.StatusForbidden, codersdk.Response{
|
||||
Message: "Token audience validation failed",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
|
||||
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
|
||||
// refreshing the OIDC token.
|
||||
@ -446,6 +462,160 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
|
||||
return key, &actor, true
|
||||
}
|
||||
|
||||
// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
|
||||
// is being used with the correct audience/resource server (RFC 8707).
|
||||
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
|
||||
// Get the OAuth2 provider app token to check its audience
|
||||
//nolint:gocritic // System needs to access token for audience validation
|
||||
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get OAuth2 token: %w", err)
|
||||
}
|
||||
|
||||
// If no audience is set, allow the request (for backward compatibility)
|
||||
if !token.Audience.Valid || token.Audience.String == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract the expected audience from the request
|
||||
expectedAudience := extractExpectedAudience(r)
|
||||
|
||||
// Normalize both audience values for RFC 3986 compliant comparison
|
||||
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
|
||||
normalizedExpectedAudience := normalizeAudienceURI(expectedAudience)
|
||||
|
||||
// Validate that the token's audience matches the expected audience
|
||||
if normalizedTokenAudience != normalizedExpectedAudience {
|
||||
return xerrors.Errorf("token audience %q does not match expected audience %q",
|
||||
token.Audience.String, expectedAudience)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeAudienceURI implements RFC 3986 URI normalization for OAuth2 audience comparison.
|
||||
// This ensures consistent audience matching between authorization and token validation.
|
||||
func normalizeAudienceURI(audienceURI string) string {
|
||||
if audienceURI == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
u, err := url.Parse(audienceURI)
|
||||
if err != nil {
|
||||
// If parsing fails, return as-is to avoid breaking existing functionality
|
||||
return audienceURI
|
||||
}
|
||||
|
||||
// Apply RFC 3986 syntax-based normalization:
|
||||
|
||||
// 1. Scheme normalization - case-insensitive
|
||||
u.Scheme = strings.ToLower(u.Scheme)
|
||||
|
||||
// 2. Host normalization - case-insensitive and IDN (punnycode) normalization
|
||||
u.Host = normalizeHost(u.Host)
|
||||
|
||||
// 3. Remove default ports for HTTP/HTTPS
|
||||
if (u.Scheme == "http" && strings.HasSuffix(u.Host, ":80")) ||
|
||||
(u.Scheme == "https" && strings.HasSuffix(u.Host, ":443")) {
|
||||
// Extract host without default port
|
||||
if idx := strings.LastIndex(u.Host, ":"); idx > 0 {
|
||||
u.Host = u.Host[:idx]
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Path normalization including dot-segment removal (RFC 3986 Section 6.2.2.3)
|
||||
u.Path = normalizePathSegments(u.Path)
|
||||
|
||||
// 5. Remove fragment - should already be empty due to earlier validation,
|
||||
// but clear it as a safety measure in case validation was bypassed
|
||||
if u.Fragment != "" {
|
||||
// This should not happen if validation is working correctly
|
||||
u.Fragment = ""
|
||||
}
|
||||
|
||||
// 6. Keep query parameters as-is (rarely used in audience URIs but preserved for compatibility)
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// normalizeHost performs host normalization including case-insensitive conversion
|
||||
// and IDN (Internationalized Domain Name) punnycode normalization.
|
||||
func normalizeHost(host string) string {
|
||||
if host == "" {
|
||||
return host
|
||||
}
|
||||
|
||||
// Handle IPv6 addresses - they are enclosed in brackets
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
// IPv6 addresses should be normalized to lowercase
|
||||
return strings.ToLower(host)
|
||||
}
|
||||
|
||||
// Extract port if present
|
||||
var port string
|
||||
if idx := strings.LastIndex(host, ":"); idx > 0 {
|
||||
// Check if this is actually a port (not part of IPv6)
|
||||
if !strings.Contains(host[idx+1:], ":") {
|
||||
port = host[idx:]
|
||||
host = host[:idx]
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to lowercase for case-insensitive comparison
|
||||
host = strings.ToLower(host)
|
||||
|
||||
// Apply IDN normalization - convert Unicode domain names to ASCII (punnycode)
|
||||
if normalizedHost, err := idna.ToASCII(host); err == nil {
|
||||
host = normalizedHost
|
||||
}
|
||||
// If IDN conversion fails, continue with lowercase version
|
||||
|
||||
return host + port
|
||||
}
|
||||
|
||||
// normalizePathSegments normalizes path segments for consistent OAuth2 audience matching.
|
||||
// Uses url.URL.ResolveReference() which implements RFC 3986 dot-segment removal.
|
||||
func normalizePathSegments(path string) string {
|
||||
if path == "" {
|
||||
// If no path is specified, use "/" for consistency with RFC 8707 examples
|
||||
return "/"
|
||||
}
|
||||
|
||||
// Use url.URL.ResolveReference() to handle dot-segment removal per RFC 3986
|
||||
base := &url.URL{Path: "/"}
|
||||
ref := &url.URL{Path: path}
|
||||
resolved := base.ResolveReference(ref)
|
||||
|
||||
normalizedPath := resolved.Path
|
||||
|
||||
// Remove trailing slash from paths longer than "/" to normalize
|
||||
// This ensures "/api/" and "/api" are treated as equivalent
|
||||
if len(normalizedPath) > 1 && strings.HasSuffix(normalizedPath, "/") {
|
||||
normalizedPath = strings.TrimSuffix(normalizedPath, "/")
|
||||
}
|
||||
|
||||
return normalizedPath
|
||||
}
|
||||
|
||||
// Test export functions for testing package access
|
||||
|
||||
// extractExpectedAudience determines the expected audience for the current request.
|
||||
// This should match the resource parameter used during authorization.
|
||||
func extractExpectedAudience(r *http.Request) string {
|
||||
// For MCP compliance, the audience should be the canonical URI of the resource server
|
||||
// This typically matches the access URL of the Coder deployment
|
||||
scheme := "https"
|
||||
if r.TLS == nil {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
// Use the Host header to construct the canonical audience URI
|
||||
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
|
||||
|
||||
// Normalize the URI according to RFC 3986 for consistent comparison
|
||||
return normalizeAudienceURI(audience)
|
||||
}
|
||||
|
||||
// UserRBACSubject fetches a user's rbac.Subject from the database. It pulls all roles from both
|
||||
// site and organization scopes. It also pulls the groups, and the user's status.
|
||||
func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, scope rbac.ExpandableScope) (rbac.Subject, database.UserStatus, error) {
|
||||
|
@ -53,3 +53,213 @@ func TestParseUUID_Invalid(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response.Message, `Invalid UUID "wrong-id"`)
|
||||
}
|
||||
|
||||
// TestNormalizeAudienceURI tests URI normalization for OAuth2 audience validation
|
||||
func TestNormalizeAudienceURI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "SimpleHTTPWithoutTrailingSlash",
|
||||
input: "http://example.com",
|
||||
expected: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "SimpleHTTPWithTrailingSlash",
|
||||
input: "http://example.com/",
|
||||
expected: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "HTTPSWithPath",
|
||||
input: "https://api.example.com/v1/",
|
||||
expected: "https://api.example.com/v1",
|
||||
},
|
||||
{
|
||||
name: "CaseNormalization",
|
||||
input: "HTTPS://API.EXAMPLE.COM/V1/",
|
||||
expected: "https://api.example.com/V1",
|
||||
},
|
||||
{
|
||||
name: "DefaultHTTPPort",
|
||||
input: "http://example.com:80/api/",
|
||||
expected: "http://example.com/api",
|
||||
},
|
||||
{
|
||||
name: "DefaultHTTPSPort",
|
||||
input: "https://example.com:443/api/",
|
||||
expected: "https://example.com/api",
|
||||
},
|
||||
{
|
||||
name: "NonDefaultPort",
|
||||
input: "http://example.com:8080/api/",
|
||||
expected: "http://example.com:8080/api",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := normalizeAudienceURI(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizeHost tests host normalization including IDN support
|
||||
func TestNormalizeHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "SimpleHost",
|
||||
input: "example.com",
|
||||
expected: "example.com",
|
||||
},
|
||||
{
|
||||
name: "HostWithPort",
|
||||
input: "example.com:8080",
|
||||
expected: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "CaseNormalization",
|
||||
input: "EXAMPLE.COM",
|
||||
expected: "example.com",
|
||||
},
|
||||
{
|
||||
name: "IPv4Address",
|
||||
input: "192.168.1.1",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6Address",
|
||||
input: "[::1]:8080",
|
||||
expected: "[::1]:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := normalizeHost(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNormalizePathSegments tests path normalization including dot-segment removal
|
||||
func TestNormalizePathSegments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "EmptyString",
|
||||
input: "",
|
||||
expected: "/",
|
||||
},
|
||||
{
|
||||
name: "SimplePath",
|
||||
input: "/api/v1",
|
||||
expected: "/api/v1",
|
||||
},
|
||||
{
|
||||
name: "PathWithDotSegments",
|
||||
input: "/api/../v1/./test",
|
||||
expected: "/v1/test",
|
||||
},
|
||||
{
|
||||
name: "TrailingSlash",
|
||||
input: "/api/v1/",
|
||||
expected: "/api/v1",
|
||||
},
|
||||
{
|
||||
name: "MultipleSlashes",
|
||||
input: "/api//v1///test",
|
||||
expected: "/api//v1///test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := normalizePathSegments(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExpectedAudience tests audience extraction from HTTP requests
|
||||
func TestExtractExpectedAudience(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SimpleHTTP",
|
||||
scheme: "http",
|
||||
host: "example.com",
|
||||
path: "/api/test",
|
||||
expected: "http://example.com/",
|
||||
},
|
||||
{
|
||||
name: "HTTPS",
|
||||
scheme: "https",
|
||||
host: "api.example.com",
|
||||
path: "/v1/users",
|
||||
expected: "https://api.example.com/",
|
||||
},
|
||||
{
|
||||
name: "WithPort",
|
||||
scheme: "http",
|
||||
host: "localhost:8080",
|
||||
path: "/api",
|
||||
expected: "http://localhost:8080/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var req *http.Request
|
||||
if tc.scheme == "https" {
|
||||
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
|
||||
} else {
|
||||
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
|
||||
}
|
||||
req.Host = tc.host
|
||||
|
||||
result := extractExpectedAudience(req)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user