mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
This pull request implements RFC 8707, Resource Indicators for OAuth 2.0 (https://datatracker.ietf.org/doc/html/rfc8707), to enhance the security of our OAuth 2.0 provider. This change enables proper audience validation and binds access tokens to their intended resource, which is crucial for preventing token misuse in multi-tenant environments or deployments with multiple resource servers.

## Key Changes:

* Resource Parameter Support: Adds support for the `resource` parameter in both the authorization (`/oauth2/authorize`) and token (`/oauth2/token`) endpoints, allowing clients to specify the intended resource server.
* Audience Validation: Implements server-side validation to ensure that the `resource` parameter provided during the token exchange matches the one from the authorization request.
* API Middleware Enforcement: Introduces a new validation step in the API authentication middleware (`coderd/httpmw/apikey.go`) to verify that the audience of the access token matches the resource server being accessed.
* Database Schema Updates:
  * Adds a `resource_uri` column to the `oauth2_provider_app_codes` table to store the resource requested during authorization.
  * Adds an `audience` column to the `oauth2_provider_app_tokens` table to bind the issued token to a specific audience.
* Enhanced PKCE: Includes a minor enhancement to the PKCE implementation to protect against timing attacks.
* Comprehensive Testing: Adds extensive new tests to `coderd/oauth2_test.go` to cover various RFC 8707 scenarios, including valid flows, mismatched resources, and refresh token validation.

## How it Works:

1. An OAuth2 client specifies the target resource (e.g., https://coder.example.com) using the `resource` parameter in the authorization request.
2. The authorization server stores this resource URI with the authorization code.
3. During the token exchange, the server validates that the client provides the same `resource` parameter.
4. The server issues an access token with an audience claim set to the validated resource URI.
5. When the client uses the access token to call an API endpoint, the middleware verifies that the token's audience matches the URL of the Coder deployment, rejecting any tokens intended for a different resource.

This ensures that a token issued for one Coder deployment cannot be used to access another, significantly strengthening our authentication security.

---

Change-Id: I3924cb2139e837e3ac0b0bd40a5aeb59637ebc1b
Signed-off-by: Thomas Kosiewski <tk@coder.com>
266 lines
5.6 KiB
Go
266 lines
5.6 KiB
Go
package httpmw
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
const (
	// testWorkspaceAgentID is the well-formed UUID string the UUID-parsing
	// tests place in the route context as the parameter value.
	testWorkspaceAgentID = "8a70c576-12dc-42bc-b791-112a32b5bd43"

	// testParam is the chi URL parameter name the UUID-parsing tests
	// register on the route context.
	testParam = "workspaceagent"
)
|
|
|
|
// TestParseUUID_Valid checks that ParseUUIDParam accepts a well-formed UUID
// supplied as a chi URL parameter and returns the same UUID.
func TestParseUUID_Valid(t *testing.T) {
	t.Parallel()

	rw := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "/{"+testParam+"}", nil)

	// Inject a chi route context carrying the parameter directly instead of
	// dispatching the request through a router.
	rctx := chi.NewRouteContext()
	rctx.URLParams.Add(testParam, testWorkspaceAgentID)
	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))

	// Look the value up by the same constant used to register it so the two
	// names cannot drift apart.
	parsed, ok := ParseUUIDParam(rw, r, testParam)
	// Stop on failure so the value comparison below does not add a second,
	// redundant failure.
	require.True(t, ok, "UUID should be parsed")
	assert.Equal(t, testWorkspaceAgentID, parsed.String())
}
|
|
|
|
// TestParseUUID_Invalid checks that ParseUUIDParam rejects a malformed UUID,
// writes a 400 Bad Request, and reports the offending value in the
// codersdk.Response message.
func TestParseUUID_Invalid(t *testing.T) {
	t.Parallel()

	rw := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "/{"+testParam+"}", nil)

	// Inject a chi route context carrying a non-UUID value for the parameter.
	rctx := chi.NewRouteContext()
	rctx.URLParams.Add(testParam, "wrong-id")
	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))

	_, ok := ParseUUIDParam(rw, r, testParam)
	require.False(t, ok, "UUID should not be parsed")
	// Check the status before decoding so a wrong status is reported directly
	// rather than as a confusing JSON or message mismatch.
	require.Equal(t, http.StatusBadRequest, rw.Code)

	var response codersdk.Response
	err := json.Unmarshal(rw.Body.Bytes(), &response)
	require.NoError(t, err)
	assert.Contains(t, response.Message, `Invalid UUID "wrong-id"`)
}
|
|
|
|
// TestNormalizeAudienceURI exercises normalizeAudienceURI, which canonicalizes
// URIs before OAuth2 audience comparison. The table pins these behaviors:
//   - empty input stays empty;
//   - a bare origin gains a root "/", while a trailing slash after a non-root
//     path is dropped;
//   - scheme and host are lower-cased, but the path's case is preserved;
//   - default ports (80 for http, 443 for https) are removed, while other
//     ports are kept.
func TestNormalizeAudienceURI(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		in   string
		want string
	}{
		{name: "EmptyString", in: "", want: ""},
		{name: "SimpleHTTPWithoutTrailingSlash", in: "http://example.com", want: "http://example.com/"},
		{name: "SimpleHTTPWithTrailingSlash", in: "http://example.com/", want: "http://example.com/"},
		{name: "HTTPSWithPath", in: "https://api.example.com/v1/", want: "https://api.example.com/v1"},
		{name: "CaseNormalization", in: "HTTPS://API.EXAMPLE.COM/V1/", want: "https://api.example.com/V1"},
		{name: "DefaultHTTPPort", in: "http://example.com:80/api/", want: "http://example.com/api"},
		{name: "DefaultHTTPSPort", in: "https://example.com:443/api/", want: "https://example.com/api"},
		{name: "NonDefaultPort", in: "http://example.com:8080/api/", want: "http://example.com:8080/api"},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := normalizeAudienceURI(tt.in)
			assert.Equal(t, tt.want, got)
		})
	}
}
|
|
|
|
// TestNormalizeHost exercises normalizeHost. The table pins these behaviors:
// empty input stays empty; host names are lower-cased; and ports, IPv4
// literals and bracketed IPv6 literals (with port) pass through unchanged.
//
// TODO(review): no internationalized domain name (IDN) case is covered yet.
// Add one once the expected output (presumably punycode) is confirmed against
// normalizeHost.
func TestNormalizeHost(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		in   string
		want string
	}{
		{name: "EmptyString", in: "", want: ""},
		{name: "SimpleHost", in: "example.com", want: "example.com"},
		{name: "HostWithPort", in: "example.com:8080", want: "example.com:8080"},
		{name: "CaseNormalization", in: "EXAMPLE.COM", want: "example.com"},
		{name: "IPv4Address", in: "192.168.1.1", want: "192.168.1.1"},
		{name: "IPv6Address", in: "[::1]:8080", want: "[::1]:8080"},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := normalizeHost(tt.in)
			assert.Equal(t, tt.want, got)
		})
	}
}
|
|
|
|
// TestNormalizePathSegments exercises normalizePathSegments. The table pins
// these behaviors:
//   - an empty path becomes "/";
//   - "." and ".." segments are resolved;
//   - a trailing slash is dropped;
//   - runs of repeated slashes are left untouched.
func TestNormalizePathSegments(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		path string
		want string
	}{
		{name: "EmptyString", path: "", want: "/"},
		{name: "SimplePath", path: "/api/v1", want: "/api/v1"},
		{name: "PathWithDotSegments", path: "/api/../v1/./test", want: "/v1/test"},
		{name: "TrailingSlash", path: "/api/v1/", want: "/api/v1"},
		{name: "MultipleSlashes", path: "/api//v1///test", want: "/api//v1///test"},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, c.want, normalizePathSegments(c.path))
		})
	}
}
|
|
|
|
// TestExtractExpectedAudience tests audience extraction from HTTP requests
|
|
func TestExtractExpectedAudience(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCases := []struct {
|
|
name string
|
|
scheme string
|
|
host string
|
|
path string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "SimpleHTTP",
|
|
scheme: "http",
|
|
host: "example.com",
|
|
path: "/api/test",
|
|
expected: "http://example.com/",
|
|
},
|
|
{
|
|
name: "HTTPS",
|
|
scheme: "https",
|
|
host: "api.example.com",
|
|
path: "/v1/users",
|
|
expected: "https://api.example.com/",
|
|
},
|
|
{
|
|
name: "WithPort",
|
|
scheme: "http",
|
|
host: "localhost:8080",
|
|
path: "/api",
|
|
expected: "http://localhost:8080/",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
var req *http.Request
|
|
if tc.scheme == "https" {
|
|
req = httptest.NewRequest("GET", "https://"+tc.host+tc.path, nil)
|
|
} else {
|
|
req = httptest.NewRequest("GET", "http://"+tc.host+tc.path, nil)
|
|
}
|
|
req.Host = tc.host
|
|
|
|
result := extractExpectedAudience(req)
|
|
assert.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|