feat: add azure oidc PKI auth instead of client secret (#9054)

* feat: add azure oidc PKI auth instead of client secret
* add client cert and key as deployment options
* Custom token refresher to handle pki auth
This commit is contained in:
Steven Masley
2023-08-14 17:33:13 -05:00
committed by GitHub
parent 4e36f91ea2
commit 25ce30df36
13 changed files with 748 additions and 35 deletions

7
coderd/apidoc/docs.go generated
View File

@ -8602,9 +8602,16 @@ const docTemplate = `{
"auth_url_params": {
"type": "object"
},
"client_cert_file": {
"type": "string"
},
"client_id": {
"type": "string"
},
"client_key_file": {
"description": "ClientKeyFile \u0026 ClientCertFile are used in place of ClientSecret for PKI auth.",
"type": "string"
},
"client_secret": {
"type": "string"
},

View File

@ -7715,9 +7715,16 @@
"auth_url_params": {
"type": "object"
},
"client_cert_file": {
"type": "string"
},
"client_id": {
"type": "string"
},
"client_key_file": {
"description": "ClientKeyFile \u0026 ClientCertFile are used in place of ClientSecret for PKI auth.",
"type": "string"
},
"client_secret": {
"type": "string"
},

273
coderd/oauthpki/oidcpki.go Normal file
View File

@ -0,0 +1,273 @@
package oauthpki
import (
"context"
"crypto/rsa"
"crypto/sha1" //#nosec // Not used for cryptography.
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/oauth2"
"golang.org/x/oauth2/jws"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/httpmw"
)
// Config uses jwt assertions over client_secret for oauth2 authentication of
// the application. This implementation was made specifically for Azure AD.
//
// https://learn.microsoft.com/en-us/azure/active-directory/develop/certificate-credentials
//
// However this does mostly follow the standard. We can generalize this as we
// include support for more IDPs.
//
// https://datatracker.ietf.org/doc/html/rfc7523
type Config struct {
cfg httpmw.OAuth2Config
// These values should match those provided in the oauth2.Config.
// Because the inner config is an interface, we need to duplicate these
// values here.
scopes []string
clientID string
tokenURL string
// ClientSecret is the private key of the PKI cert.
// Azure AD only supports RS256 signing algorithm.
clientKey *rsa.PrivateKey
// Base64url-encoded SHA-1 thumbprint of the X.509 certificate's DER encoding.
// This is specific to Azure AD
x5t string
}
type ConfigParams struct {
ClientID string
TokenURL string
Scopes []string
PemEncodedKey []byte
PemEncodedCert []byte
Config httpmw.OAuth2Config
}
// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key.
// The values should be passed in as PEM encoded values, which is the standard encoding for x509 certs saved to disk.
// It should look like:
//
// -----BEGIN RSA PRIVATE KEY----
// ...
// -----END RSA PRIVATE KEY-----
//
// -----BEGIN CERTIFICATE-----
// ...
// -----END CERTIFICATE-----
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
if params.ClientID == "" {
return nil, xerrors.Errorf("")
}
if len(params.Scopes) == 0 {
return nil, xerrors.Errorf("scopes are required")
}
rsaKey, err := decodeClientKey(params.PemEncodedKey)
if err != nil {
return nil, err
}
// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
// This is not required in the general specification.
if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
return nil, xerrors.Errorf("oidc client certificate is required and missing")
}
block, _ := pem.Decode(params.PemEncodedCert)
// Used as an identifier, not an actual cryptographic hash.
//nolint:gosec
hashed := sha1.Sum(block.Bytes)
return &Config{
clientID: params.ClientID,
tokenURL: params.TokenURL,
scopes: params.Scopes,
cfg: params.Config,
clientKey: rsaKey,
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
}, nil
}
// decodeClientKey decodes a PEM encoded rsa secret.
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(pemEncoded)
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, xerrors.Errorf("failed to parse private key: %w", err)
}
return key, nil
}
func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return ja.cfg.AuthCodeURL(state, opts...)
}
// Exchange includes the client_assertion signed JWT.
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
signed, err := ja.jwtToken()
if err != nil {
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
}
opts = append(opts,
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
oauth2.SetAuthURLParam("client_assertion", signed),
)
return ja.cfg.Exchange(ctx, code, opts...)
}
func (ja *Config) jwtToken() (string, error) {
now := time.Now()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"iss": ja.clientID,
"sub": ja.clientID,
"aud": ja.tokenURL,
// 5-10 minutes is recommended in the Azure docs.
// So we'll use 5 minutes.
"exp": now.Add(time.Minute * 5).Unix(),
"jti": uuid.New().String(),
"nbf": now.Unix(),
"iat": now.Unix(),
})
token.Header["x5t"] = ja.x5t
signed, err := token.SignedString(ja.clientKey)
if err != nil {
return "", xerrors.Errorf("sign jwt assertion: %w", err)
}
return signed, nil
}
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
cfg: ja,
ctx: ctx,
refreshToken: token.RefreshToken,
})
}
type jwtTokenSource struct {
cfg *Config
ctx context.Context
refreshToken string
}
// Token must be safe for concurrent use by multiple go routines
// Very similar to the RetrieveToken implementation by the oauth2 package.
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
// Oauth2 package keeps this code unexported or in an /internal package,
// so we have to copy the implementation :(
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
if src.refreshToken == "" {
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
}
cli := http.DefaultClient
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
cli = v
}
token, err := src.cfg.jwtToken()
if err != nil {
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
}
v := url.Values{
"client_assertion": {token},
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
"client_id": {src.cfg.clientID},
"grant_type": {"refresh_token"},
"scope": {strings.Join(src.cfg.scopes, " ")},
"refresh_token": {src.refreshToken},
}
// Using params based auth
req, err := http.NewRequest("POST", src.cfg.tokenURL, strings.NewReader(v.Encode()))
if err != nil {
return nil, xerrors.Errorf("oauth2: make token refresh request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(src.ctx)
resp, err := cli.Do(req)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
}
var tokenRes struct {
oauth2.Token
// Extra fields returned by the refresh that are needed
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
// error fields
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
}
unmarshalError := json.Unmarshal(body, &tokenRes)
if resp.StatusCode < 200 || resp.StatusCode > 299 {
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
// can be encoded in a few places, so this does not catch all of them.
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
// Best effort for error fields
ErrorCode: tokenRes.ErrorCode,
ErrorDescription: tokenRes.ErrorDescription,
ErrorURI: tokenRes.ErrorURI,
}
}
if unmarshalError != nil {
return nil, fmt.Errorf("oauth2: cannot unmarshal token: %w", err)
}
newToken := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
}
if secs := tokenRes.ExpiresIn; secs > 0 {
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}
// ID token is a JWT token. We can decode it to get the expiry.
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
// but this one is attached in the JWT and guaranteed to be right for local
// validation. So use this one if found.
if v := tokenRes.IDToken; v != "" {
// decode returned id token to get expiry
claimSet, err := jws.Decode(v)
if err != nil {
return nil, fmt.Errorf("oauth2: error decoding JWT token: %w", err)
}
newToken.Expiry = time.Unix(claimSet.Exp, 0)
}
return newToken, nil
}

View File

@ -0,0 +1,296 @@
package oauthpki_test
import (
"context"
"encoding/base64"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/oauthpki"
"github.com/coder/coder/testutil"
)
const (
testClientKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tIdqSqu
ASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaNcoN8
wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMWUJpV
g8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyFkD35
xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64ZxsqfuM4LX
BakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABAoIBAQCb+aH7x0IylSir
r1Z06RDBI9bunOwBA9aqkwdRuCg4zGsVQXljNnABgACz7837JQPRIUW2MU553otX
yyE+RzNnsjkLxSgbqvSFOe+FDOx7iB5jm/euf4NNmZ0lU3iggurgJ6iVsgVgrQUF
AyXX+d2gawLUDYjBwxgozkSodH2sXYSX+SWfSOXHsFzSa3tLtUMbAIflM0rlRXf7
Z57M8mMomZUvmmojH+TnBQljJlU8lhrvOaDD4DT8qAtVHE3VluDBQ9/3E8OIjz+E
EqUgWLgrdq1rIMhJbHN90NwLwWs+2PcRfdB6hqKPktLne2KZFOgVKlxPKOYByBq1
PX/vJ/HBAoGBAMFmJ6nYqyUVl26ajlXmnXBjQ+iBLHo9lcUu84+rpqRf90Bsm5bd
jMmYr3Yo3yXNiit3rvZzBfPElo+IVa1HpPtgOaa2AU5B3QzxWCNT0FNRQqMG2LcA
CvB10pOdJEABQxr7d4eFRg2/KbF1fr0r0vqMEelwa5ejTg6ROD3DtadpAoGBANA0
4EClniCwvd1IECy2oTuTDosXgmRKwRAcwgE34YXy1Y/L4X/ghFeCHi3ybrep0uwL
ptJNK+0sqvPu6UhC356GfMqfuzOKNMkXybnPUbHrz5KTkN+QQMfPc73Veel2gpD3
xNataEmHtxcOx0X1OnjwyZZpmMbrUY3Cackn+durAoGBAKYR5nU+jJfnloVvSlIR
GZhsZN++LEc7ouQTkSoJp6r2jQZRPLmrvT1PUzwPlK6NdNwmhaMy2iWc5fySgZ+u
KcmBs3+oQi7E9+ApThnn2rfwy1vagTWDX+FkC1KeWYZsjwcYcGd61dDwGgk8b3xZ
qW1j4e2mj31CycBQiw7eg5ohAoGADvkOe3etlHpBXS12hFCp7afYruYE6YN6uNbo
mL/VBxX8h7fIwrJ5sfVYiENb9PdQhMsdtxf3pbnFnX875Ydxn2vag5PTGZTB0QhV
6HfhTyM/LTJRg9JS5kuj7i3w83ojT5uR20JjMo6A+zaD3CMTjmj6hkeXxg5cMg6e
HuoyDLsCgYBcbboYMFT1cUSxBeMtPGt3CxxZUYnUQaRUeOcjqYYlFL+DCWhY7pxH
EnLhwW/KzkDzOmwRmmNOMqD7UhR/ayxR+avRt6v5d5l8fVCuNexgs7kR9L5IQp9l
YV2wsCoXBCcuPmio/te44U//BlzprEu0w1iHpb3ibmQg4y291R0TvQ==
-----END RSA PRIVATE KEY-----`
testClientCert = `
-----BEGIN CERTIFICATE-----
MIIEOjCCAiKgAwIBAgIQMO50KnWsRbmrrthPQgyubjANBgkqhkiG9w0BAQsFADAY
MRYwFAYDVQQDEw1Mb2NhbGhvc3RDZXJ0MB4XDTIzMDgxMDE2MjYxOFoXDTI1MDIx
MDE2MjU0M1owFDESMBAGA1UEAxMJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tId
qSquASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaN
coN8wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMW
UJpVg8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyF
kD35xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64Zxsqfu
M4LXBakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABo4GDMIGAMA4GA1Ud
DwEB/wQEAwIDuDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0O
BBYEFAYCdgydG3h2SNWF+BfAyJtNliJtMB8GA1UdIwQYMBaAFHR/aptP0RUNNFyf
5uky527SECt1MA8GA1UdEQQIMAaHBH8AAAEwDQYJKoZIhvcNAQELBQADggIBAI6P
ymG7l06JvJ3p6xgaMyOxgkpQl6WkY4LJHVEhfeDSoO3qsJc4PxUdSExJsT84weXb
lF+tK6D/CPlvjmG720IlB5cSKJ71rWjwmaMWKxWKXyoZdDrHAp55+FNdXegUZF2o
EF/ZM5CHaO8iHMkuWEv1OASHBQWC/o4spUN5HGQ9HepwLVvO/aX++LYfvfL9faKA
IT+w9i8pJbfItFmfA8x2OEVZk8aEA0WtKdfsMwzGmZ1GSGa4UYcynxQGCMiB5h4L
C/dpoJRbEzdGLuTZgV2SCaN3k5BrH4aaILI9tqZaq0gamN9Rd2yji3cGiduCeAAo
RmVcl9fBliMLxylWEP5+B2JmCZEc8Lfm0TBNnjaG17KY40gzbfBYixBxBTYgsPua
bfprtfksSG++zcsDbkC8CtPamtlNWtDAiFp4yQRkP79PlJO6qCdTrFWPukTMCMso
25hjLvxj1fLy/jSMDEZu/oQ14TMCZSGHRjz4CPiaCfXqgqOtVOD+5+yWInwUGp/i
Nb1vIq4ruEAbyCbdWKHbE0yT5AP7hm5ZNybpZ4/311AEBD2HKip/OqB05p99XcLw
BIC4ODNvwCn6x00KZoqWz/MX2dEQ/HqWiWaDB/OSemfTVE3I94mzEWnqpF2cQpcT
B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
-----END CERTIFICATE-----
`
)
// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
// It runs an oauth2.Exchange method and hijacks the request to only check the
// request side of the transaction.
func TestAzureADPKIOIDC(t *testing.T) {
t.Parallel()
oauthCfg := &oauth2.Config{
ClientID: "random-client-id",
Endpoint: oauth2.Endpoint{
TokenURL: "https://login.microsoftonline.com/6a1e9139-13f2-4afb-8f46-036feac8bd79/v2.0/token",
},
}
pkiConfig, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: oauthCfg.ClientID,
TokenURL: oauthCfg.Endpoint.TokenURL,
PemEncodedKey: []byte(testClientKey),
PemEncodedCert: []byte(testClientCert),
Config: oauthCfg,
Scopes: []string{"openid", "email", "profile"},
})
require.NoError(t, err, "failed to create pki config")
ctx := testutil.Context(t, testutil.WaitMedium)
ctx = oidc.ClientContext(ctx, &http.Client{
Transport: &fakeRoundTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
resp := &http.Response{
Status: "500 Internal Service Error",
}
// This is the easiest way to hijack the request and check
// the params. The oauth2 package uses unexported types and
// options, so we need to view the actual request created.
assertJWTAuth(t, req)
return resp, nil
},
},
})
_, err = pkiConfig.Exchange(ctx, base64.StdEncoding.EncodeToString([]byte("random-code")))
// We hijack the request and return an error intentionally
require.Error(t, err, "error expected")
}
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
// AD instance and saving them to replay, removing some details.
// The reason this is done is that this is the only way to assert values
// passed to the oauth2 provider via http requests.
// It is not feasible to run against an actual Azure AD instance, so this attempts
// to prevent some regressions by running a full "e2e" oauth and asserting some
// of the request values.
func TestSavedAzureADPKIOIDC(t *testing.T) {
t.Parallel()
var (
stateString = "random-state"
oauth2Code = base64.StdEncoding.EncodeToString([]byte("random-code"))
)
// Real oauth config. We will hijack all http requests so some of these values
// are fake.
cfg := &oauth2.Config{
ClientID: "fake_app",
ClientSecret: "",
Endpoint: oauth2.Endpoint{
AuthURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/token",
AuthStyle: 0,
},
RedirectURL: "http://localhost/api/v2/users/oidc/callback",
Scopes: []string{"openid", "profile", "email", "offline_access"},
}
initialExchange := false
tokenRefreshed := false
// Create the oauthpki config
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
ClientID: cfg.ClientID,
TokenURL: cfg.Endpoint.TokenURL,
Scopes: []string{"openid", "email", "profile", "offline_access"},
PemEncodedKey: []byte(testClientKey),
PemEncodedCert: []byte(testClientCert),
Config: cfg,
})
require.NoError(t, err)
var fakeCtx context.Context
fakeClient := &http.Client{
Transport: fakeRoundTripper{
roundTrip: func(req *http.Request) (*http.Response, error) {
if strings.Contains(req.URL.String(), "authorize") {
// This is the user hitting the browser endpoint to begin the OIDC
// auth flow.
// Authorize should redirect the user back to the app after authentication on
// the IDP.
resp := httptest.NewRecorder()
v := url.Values{
"code": {oauth2Code},
"state": {stateString},
"session_state": {"a18cf797-1e2b-4bc3-baf9-66b41a4997cf"},
}
// This url doesn't really matter since the fake client will hiject this actual request.
http.Redirect(resp, req, "http://localhost:3000/api/v2/users/oidc/callback?"+v.Encode(), http.StatusTemporaryRedirect)
return resp.Result(), nil
}
if strings.Contains(req.URL.String(), "v2.0/token") {
vals := assertJWTAuth(t, req)
switch vals.Get("grant_type") {
case "authorization_code":
// Initial token
initialExchange = true
assert.Equal(t, oauth2Code, vals.Get("code"), "initial exchange code mismatch")
case "refresh_token":
// refreshed token
tokenRefreshed = true
assert.Equal(t, "<refresh_token_JWT>", vals.Get("refresh_token"), "refresh token required")
}
resp := httptest.NewRecorder()
// Taken from an actual response
// Just always return a token no matter what.
resp.Header().Set("Content-Type", "application/json")
_, _ = resp.Write([]byte(`{
"token_type":"Bearer",
"scope":"email openid profile AccessReview.ReadWrite.Membership Group.Read.All Group.ReadWrite.All User.Read",
"expires_in":4009,
"ext_expires_in":4009,
"access_token":"<access_token_JWT>",
"refresh_token":"<refresh_token_JWT>",
"id_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiIxZjAxODMyYS1mZWViLTQyZGMtODFkOS01ZjBhYjZhMDQxZTAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vMTEwZjBjMGYtY2Q3Ni00NzE3LWE2ZjgtNGVlYTNkMGY4MTA5L3YyLjAiLCJpYXQiOjE2OTE3OTI2MzQsIm5iZiI6MTY5MTc5MjYzNCwiZXhwIjoxNjkxNzk2NTM0LCJhaW8iOiJBWVFBZS84VUFBQUE1eEtqMmVTdWFXVmZsRlhCeGJJTnMvSkVyVHFvUGlaQW5ENmJIZWF3a2RRcisyRVRwM3RGNGY3akxicnh3ODhhVm9QOThrY0xMNjhON1hVV3FCN1I1N2JQRU9EclRlSUI1S0lyUHBjbCtIeXR0a1ljOVdWQklVVEErSllQbzl1a0ZjbGNWZ1krWUc3eHlmdi90K3Q1ZEczblNuZEdEQ1FYRVIxbDlTNko1T2c9IiwiZW1haWwiOiJzdGV2ZW5AY29kZXIuY29tIiwiZ3JvdXBzIjpbImM4MDQ4ZTkxLWY1YzMtNDdlNS05NjkzLTgzNGRlODQwMzRhZCIsIjcwYjQ4MTc1LTEwN2ItNGFkOC1iNDA1LTRkODg4YTFjNDY2ZiJdLCJpZHAiOiJtYWlsIiwibmFtZSI6IlN0ZXZlbiBNIiwib2lkIjoiN2JhNDYzNjAtZTAyNy00OTVhLTlhZTUtM2FlYWZlMzY3MGEyIiwicHJlZmVycmVkX3VzZXJuYW1lIjoic3RldmVuQGNvZGVyLmNvbSIsInByb3ZfZGF0YSI6W3siQXQiOnRydWUsIlByb3YiOiJnaXRodWIuY29tIiwiQWx0c2VjaWQiOiI1NDQ2Mjk4IiwiQWNjZXNzVG9rZW4iOm51bGx9XSwicmgiOiIwLkFUZ0FEd3dQRVhiTkYwZW0tRTdxUFEtQkNTcURBUl9yX3R4Q2dkbGZDcmFnUWVBNEFPRS4iLCJyb2xlcyI6WyJUZW1wbGF0ZUF1dGhvcnMiXSwic3ViIjoib0JTN3FjUERKdWlDMEYyQ19XdDJycVlvanhpT0o3S3JFWjlkQ1RkTGVYNCIsInRpZCI6IjExMGYwYzBmLWNkNzYtNDcxNy1hNmY4LTRlZWEzZDBmODEwOSIsInV0aSI6IktReGlIWGtaZUVxcC1tQWlVdTlyQUEiLCJ2ZXIiOiIyLjAiLCJyb2xlczIiOiJUZW1wbGF0ZUF1dGhvcnMifQ.JevFI4Xm9dW7kQq4xEgZnUaU0SqbeOAFtT0YIKQNefR9Db4sjxCaKRmX0pPt-CM9j45d6fAiAkLFDAqjlSbi4Zi0GbEomT3yegmuxKgEgjPpJlGjF2TBUpsNNyn5gJ9Wkct9BfwALJhX2ePJFzIlkvx9opNNbNK1qHKMMjOSRFG6AGExKRDiQAME0a4hVgCwrAdUs4JrCcj4LqB84dODN-eoh-jx2-1wDvf6fovfwLHDQwjY4lfBxaYdNavKM369hrhU-U067rSnCzvDD26f4VLhPF52hiQIbTVN5t7p_1XmcduUiaNnmr9AZiZxZ-94mctSRRR8xG0pNwO2yv84iA"
}`))
return resp.Result(), nil
}
// This is the "Coder" half of things. We can keep this in the fake
// client, essentially being the fake client on both sides of the OIDC
// flow.
if strings.Contains(req.URL.String(), "v2/users/oidc/callback") {
// This is the callback from the IDP.
code := req.URL.Query().Get("code")
require.Equal(t, oauth2Code, code, "code mismatch")
state := req.URL.Query().Get("state")
require.Equal(t, stateString, state, "state mismatch")
// Exchange for token should work
token, err := pki.Exchange(fakeCtx, code)
if !assert.NoError(t, err) {
return httptest.NewRecorder().Result(), nil
}
// Also try a refresh
cpy := token
cpy.Expiry = time.Now().Add(time.Minute * -1)
src := pki.TokenSource(fakeCtx, cpy)
_, err = src.Token()
tokenRefreshed = true
assert.NoError(t, err, "token refreshed")
return httptest.NewRecorder().Result(), nil
}
return nil, xerrors.Errorf("not implemented")
},
},
}
fakeCtx = oidc.ClientContext(context.Background(), fakeClient)
_ = fakeCtx
// This simulates a client logging into the browser. The 307 redirect will
// make sure this goes through the full flow.
// nolint: noctx
resp, err := fakeClient.Get(pki.AuthCodeURL("state", oauth2.AccessTypeOffline))
require.NoError(t, err)
_ = resp.Body.Close()
require.True(t, initialExchange, "initial token exchange complete")
require.True(t, tokenRefreshed, "token was refreshed")
}
type fakeRoundTripper struct {
roundTrip func(req *http.Request) (*http.Response, error)
}
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f.roundTrip(req)
}
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
// url.Values from the request body for any additional assertions to be made.
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
body, err := io.ReadAll(r.Body)
if !assert.NoError(t, err) {
return nil
}
vals, err := url.ParseQuery(string(body))
if !assert.NoError(t, err) {
return nil
}
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
jwtToken := vals.Get("client_assertion")
// No need to actually verify the jwt is signed right.
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
if !assert.NoError(t, err, "failed to parse jwt token") {
return nil
}
// Azure requirements
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
return vals
}