feat: oauth2 - add RFC 8707 resource indicators and audience validation (#18575)

This pull request implements RFC 8707, Resource Indicators for OAuth 2.0 (https://datatracker.ietf.org/doc/html/rfc8707), to enhance the security of our OAuth 2.0 provider. 

This change enables proper audience validation and binds access tokens to their intended resource, which is crucial
  for preventing token misuse in multi-tenant environments or deployments with multiple resource servers.

##  Key Changes:


   * Resource Parameter Support: Adds support for the resource parameter in both the authorization (`/oauth2/authorize`) and token (`/oauth2/token`) endpoints, allowing clients to specify the intended resource server.
   * Audience Validation: Implements server-side validation to ensure that the resource parameter provided during the token exchange matches the one from the authorization request.
   * API Middleware Enforcement: Introduces a new validation step in the API authentication middleware (`coderd/httpmw/apikey.go`) to verify that the audience of the access token matches the resource server being accessed.
   * Database Schema Updates:
       * Adds a `resource_uri` column to the `oauth2_provider_app_codes` table to store the resource requested during authorization.
       * Adds an `audience` column to the `oauth2_provider_app_tokens` table to bind the issued token to a specific audience.
   * Enhanced PKCE: Includes a minor enhancement to the PKCE implementation to protect against timing attacks.
   * Comprehensive Testing: Adds extensive new tests to `coderd/oauth2_test.go` to cover various RFC 8707 scenarios, including valid flows, mismatched resources, and refresh token validation.

##  How it Works:


   1. An OAuth2 client specifies the target resource (e.g., https://coder.example.com) using the resource parameter in the authorization request.
   2. The authorization server stores this resource URI with the authorization code.
   3. During the token exchange, the server validates that the client provides the same resource parameter.
   4. The server issues an access token with an audience claim set to the validated resource URI.
   5. When the client uses the access token to call an API endpoint, the middleware verifies that the token's audience matches the URL of the Coder deployment, rejecting any tokens intended for a different resource.


  This ensures that a token issued for one Coder deployment cannot be used to access another, significantly strengthening our authentication security.

---

Change-Id: I3924cb2139e837e3ac0b0bd40a5aeb59637ebc1b
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2025-07-02 17:49:00 +02:00
committed by GitHub
parent 01163ea57b
commit f0c9c4dbcd
22 changed files with 1008 additions and 57 deletions

View File

@ -15,9 +15,11 @@ import (
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/net/idna"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
@ -110,6 +112,9 @@ type ExtractAPIKeyConfig struct {
// This is originally implemented to send entitlement warning headers after
// a user is authenticated to prevent additional CLI invocations.
PostAuthAdditionalHeadersFunc func(a rbac.Subject, header http.Header)
// Logger is used for logging middleware operations.
Logger slog.Logger
}
// ExtractAPIKeyMW calls ExtractAPIKey with the given config on each request,
@ -240,6 +245,17 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
})
}
// Validate OAuth2 provider app token audience (RFC 8707) if applicable
if key.LoginType == database.LoginTypeOAuth2ProviderApp {
if err := validateOAuth2ProviderAppTokenAudience(ctx, cfg.DB, *key, r); err != nil {
// Log the detailed error for debugging but don't expose it to the client
cfg.Logger.Debug(ctx, "oauth2 token audience validation failed", slog.Error(err))
return optionalWrite(http.StatusForbidden, codersdk.Response{
Message: "Token audience validation failed",
})
}
}
// We only check OIDC stuff if we have a valid APIKey. An expired key means we don't trust the requestor
// really is the user whose key they have, and so we shouldn't be doing anything on their behalf including possibly
// refreshing the OIDC token.
@ -446,6 +462,160 @@ func ExtractAPIKey(rw http.ResponseWriter, r *http.Request, cfg ExtractAPIKeyCon
return key, &actor, true
}
// validateOAuth2ProviderAppTokenAudience validates that an OAuth2 provider app token
// is being used with the correct audience/resource server (RFC 8707).
func validateOAuth2ProviderAppTokenAudience(ctx context.Context, db database.Store, key database.APIKey, r *http.Request) error {
// Get the OAuth2 provider app token to check its audience
//nolint:gocritic // System needs to access token for audience validation
token, err := db.GetOAuth2ProviderAppTokenByAPIKeyID(dbauthz.AsSystemRestricted(ctx), key.ID)
if err != nil {
return xerrors.Errorf("failed to get OAuth2 token: %w", err)
}
// If no audience is set, allow the request (for backward compatibility)
if !token.Audience.Valid || token.Audience.String == "" {
return nil
}
// Extract the expected audience from the request
expectedAudience := extractExpectedAudience(r)
// Normalize both audience values for RFC 3986 compliant comparison
normalizedTokenAudience := normalizeAudienceURI(token.Audience.String)
normalizedExpectedAudience := normalizeAudienceURI(expectedAudience)
// Validate that the token's audience matches the expected audience
if normalizedTokenAudience != normalizedExpectedAudience {
return xerrors.Errorf("token audience %q does not match expected audience %q",
token.Audience.String, expectedAudience)
}
return nil
}
// normalizeAudienceURI implements RFC 3986 URI normalization for OAuth2 audience comparison.
// This ensures consistent audience matching between authorization and token validation.
func normalizeAudienceURI(audienceURI string) string {
if audienceURI == "" {
return ""
}
u, err := url.Parse(audienceURI)
if err != nil {
// If parsing fails, return as-is to avoid breaking existing functionality
return audienceURI
}
// Apply RFC 3986 syntax-based normalization:
// 1. Scheme normalization - case-insensitive
u.Scheme = strings.ToLower(u.Scheme)
// 2. Host normalization - case-insensitive and IDN (punnycode) normalization
u.Host = normalizeHost(u.Host)
// 3. Remove default ports for HTTP/HTTPS
if (u.Scheme == "http" && strings.HasSuffix(u.Host, ":80")) ||
(u.Scheme == "https" && strings.HasSuffix(u.Host, ":443")) {
// Extract host without default port
if idx := strings.LastIndex(u.Host, ":"); idx > 0 {
u.Host = u.Host[:idx]
}
}
// 4. Path normalization including dot-segment removal (RFC 3986 Section 6.2.2.3)
u.Path = normalizePathSegments(u.Path)
// 5. Remove fragment - should already be empty due to earlier validation,
// but clear it as a safety measure in case validation was bypassed
if u.Fragment != "" {
// This should not happen if validation is working correctly
u.Fragment = ""
}
// 6. Keep query parameters as-is (rarely used in audience URIs but preserved for compatibility)
return u.String()
}
// normalizeHost performs host normalization including case-insensitive conversion
// and IDN (Internationalized Domain Name) punnycode normalization.
func normalizeHost(host string) string {
if host == "" {
return host
}
// Handle IPv6 addresses - they are enclosed in brackets
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
// IPv6 addresses should be normalized to lowercase
return strings.ToLower(host)
}
// Extract port if present
var port string
if idx := strings.LastIndex(host, ":"); idx > 0 {
// Check if this is actually a port (not part of IPv6)
if !strings.Contains(host[idx+1:], ":") {
port = host[idx:]
host = host[:idx]
}
}
// Convert to lowercase for case-insensitive comparison
host = strings.ToLower(host)
// Apply IDN normalization - convert Unicode domain names to ASCII (punnycode)
if normalizedHost, err := idna.ToASCII(host); err == nil {
host = normalizedHost
}
// If IDN conversion fails, continue with lowercase version
return host + port
}
// normalizePathSegments normalizes path segments for consistent OAuth2 audience matching.
// Uses url.URL.ResolveReference() which implements RFC 3986 dot-segment removal.
func normalizePathSegments(path string) string {
if path == "" {
// If no path is specified, use "/" for consistency with RFC 8707 examples
return "/"
}
// Use url.URL.ResolveReference() to handle dot-segment removal per RFC 3986
base := &url.URL{Path: "/"}
ref := &url.URL{Path: path}
resolved := base.ResolveReference(ref)
normalizedPath := resolved.Path
// Remove trailing slash from paths longer than "/" to normalize
// This ensures "/api/" and "/api" are treated as equivalent
if len(normalizedPath) > 1 && strings.HasSuffix(normalizedPath, "/") {
normalizedPath = strings.TrimSuffix(normalizedPath, "/")
}
return normalizedPath
}
// Test export functions for testing package access
// extractExpectedAudience determines the expected audience for the current request.
// This should match the resource parameter used during authorization.
func extractExpectedAudience(r *http.Request) string {
// For MCP compliance, the audience should be the canonical URI of the resource server
// This typically matches the access URL of the Coder deployment
scheme := "https"
if r.TLS == nil {
scheme = "http"
}
// Use the Host header to construct the canonical audience URI
audience := fmt.Sprintf("%s://%s", scheme, r.Host)
// Normalize the URI according to RFC 3986 for consistent comparison
return normalizeAudienceURI(audience)
}
// UserRBACSubject fetches a user's rbac.Subject from the database. It pulls all roles from both
// site and organization scopes. It also pulls the groups, and the user's status.
func UserRBACSubject(ctx context.Context, db database.Store, userID uuid.UUID, scope rbac.ExpandableScope) (rbac.Subject, database.UserStatus, error) {

View File

@ -53,3 +53,213 @@ func TestParseUUID_Invalid(t *testing.T) {
require.NoError(t, err)
assert.Contains(t, response.Message, `Invalid UUID "wrong-id"`)
}
// TestNormalizeAudienceURI tests URI normalization for OAuth2 audience validation
func TestNormalizeAudienceURI(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHTTPWithoutTrailingSlash",
input: "http://example.com",
expected: "http://example.com/",
},
{
name: "SimpleHTTPWithTrailingSlash",
input: "http://example.com/",
expected: "http://example.com/",
},
{
name: "HTTPSWithPath",
input: "https://api.example.com/v1/",
expected: "https://api.example.com/v1",
},
{
name: "CaseNormalization",
input: "HTTPS://API.EXAMPLE.COM/V1/",
expected: "https://api.example.com/V1",
},
{
name: "DefaultHTTPPort",
input: "http://example.com:80/api/",
expected: "http://example.com/api",
},
{
name: "DefaultHTTPSPort",
input: "https://example.com:443/api/",
expected: "https://example.com/api",
},
{
name: "NonDefaultPort",
input: "http://example.com:8080/api/",
expected: "http://example.com:8080/api",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeAudienceURI(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizeHost tests host normalization including IDN support
func TestNormalizeHost(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "",
},
{
name: "SimpleHost",
input: "example.com",
expected: "example.com",
},
{
name: "HostWithPort",
input: "example.com:8080",
expected: "example.com:8080",
},
{
name: "CaseNormalization",
input: "EXAMPLE.COM",
expected: "example.com",
},
{
name: "IPv4Address",
input: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6Address",
input: "[::1]:8080",
expected: "[::1]:8080",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizeHost(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestNormalizePathSegments tests path normalization including dot-segment removal
func TestNormalizePathSegments(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input string
expected string
}{
{
name: "EmptyString",
input: "",
expected: "/",
},
{
name: "SimplePath",
input: "/api/v1",
expected: "/api/v1",
},
{
name: "PathWithDotSegments",
input: "/api/../v1/./test",
expected: "/v1/test",
},
{
name: "TrailingSlash",
input: "/api/v1/",
expected: "/api/v1",
},
{
name: "MultipleSlashes",
input: "/api//v1///test",
expected: "/api//v1///test",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := normalizePathSegments(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}
// TestExtractExpectedAudience tests audience extraction from HTTP requests
func TestExtractExpectedAudience(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
scheme string
host string
path string
expected string
}{
{
name: "SimpleHTTP",
scheme: "http",
host: "example.com",
path: "/api/test",
expected: "http://example.com/",
},
{
name: "HTTPS",
scheme: "https",
host: "api.example.com",
path: "/v1/users",
expected: "https://api.example.com/",
},
{
name: "WithPort",
scheme: "http",
host: "localhost:8080",
path: "/api",
expected: "http://localhost:8080/",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var req *http.Request
if tc.scheme == "https" {
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
} else {
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
}
req.Host = tc.host
result := extractExpectedAudience(req)
assert.Equal(t, tc.expected, result)
})
}
}