mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat: add azure oidc PKI auth instead of client secret (#9054)
* feat: add azure oidc PKI auth instead of client secret * add client cert and key as deployment options * Custom token refresher to handle pki auth
This commit is contained in:
7
coderd/apidoc/docs.go
generated
7
coderd/apidoc/docs.go
generated
@ -8602,9 +8602,16 @@ const docTemplate = `{
|
||||
"auth_url_params": {
|
||||
"type": "object"
|
||||
},
|
||||
"client_cert_file": {
|
||||
"type": "string"
|
||||
},
|
||||
"client_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"client_key_file": {
|
||||
"description": "ClientKeyFile \u0026 ClientCertFile are used in place of ClientSecret for PKI auth.",
|
||||
"type": "string"
|
||||
},
|
||||
"client_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
|
7
coderd/apidoc/swagger.json
generated
7
coderd/apidoc/swagger.json
generated
@ -7715,9 +7715,16 @@
|
||||
"auth_url_params": {
|
||||
"type": "object"
|
||||
},
|
||||
"client_cert_file": {
|
||||
"type": "string"
|
||||
},
|
||||
"client_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"client_key_file": {
|
||||
"description": "ClientKeyFile \u0026 ClientCertFile are used in place of ClientSecret for PKI auth.",
|
||||
"type": "string"
|
||||
},
|
||||
"client_secret": {
|
||||
"type": "string"
|
||||
},
|
||||
|
273
coderd/oauthpki/oidcpki.go
Normal file
273
coderd/oauthpki/oidcpki.go
Normal file
@ -0,0 +1,273 @@
|
||||
package oauthpki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1" //#nosec // Not used for cryptography.
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/jws"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
// Config uses jwt assertions over client_secret for oauth2 authentication of
|
||||
// the application. This implementation was made specifically for Azure AD.
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/azure/active-directory/develop/certificate-credentials
|
||||
//
|
||||
// However this does mostly follow the standard. We can generalize this as we
|
||||
// include support for more IDPs.
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc7523
|
||||
type Config struct {
|
||||
cfg httpmw.OAuth2Config
|
||||
|
||||
// These values should match those provided in the oauth2.Config.
|
||||
// Because the inner config is an interface, we need to duplicate these
|
||||
// values here.
|
||||
scopes []string
|
||||
clientID string
|
||||
tokenURL string
|
||||
|
||||
// ClientSecret is the private key of the PKI cert.
|
||||
// Azure AD only supports RS256 signing algorithm.
|
||||
clientKey *rsa.PrivateKey
|
||||
// Base64url-encoded SHA-1 thumbprint of the X.509 certificate's DER encoding.
|
||||
// This is specific to Azure AD
|
||||
x5t string
|
||||
}
|
||||
|
||||
type ConfigParams struct {
|
||||
ClientID string
|
||||
TokenURL string
|
||||
Scopes []string
|
||||
PemEncodedKey []byte
|
||||
PemEncodedCert []byte
|
||||
|
||||
Config httpmw.OAuth2Config
|
||||
}
|
||||
|
||||
// NewOauth2PKIConfig creates the oauth2 config for PKI based auth. It requires the certificate and it's private key.
|
||||
// The values should be passed in as PEM encoded values, which is the standard encoding for x509 certs saved to disk.
|
||||
// It should look like:
|
||||
//
|
||||
// -----BEGIN RSA PRIVATE KEY----
|
||||
// ...
|
||||
// -----END RSA PRIVATE KEY-----
|
||||
//
|
||||
// -----BEGIN CERTIFICATE-----
|
||||
// ...
|
||||
// -----END CERTIFICATE-----
|
||||
func NewOauth2PKIConfig(params ConfigParams) (*Config, error) {
|
||||
if params.ClientID == "" {
|
||||
return nil, xerrors.Errorf("")
|
||||
}
|
||||
if len(params.Scopes) == 0 {
|
||||
return nil, xerrors.Errorf("scopes are required")
|
||||
}
|
||||
|
||||
rsaKey, err := decodeClientKey(params.PemEncodedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Azure AD requires a certificate. The sha1 of the cert is used to identify the signer.
|
||||
// This is not required in the general specification.
|
||||
if strings.Contains(strings.ToLower(params.TokenURL), "microsoftonline") && len(params.PemEncodedCert) == 0 {
|
||||
return nil, xerrors.Errorf("oidc client certificate is required and missing")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(params.PemEncodedCert)
|
||||
// Used as an identifier, not an actual cryptographic hash.
|
||||
//nolint:gosec
|
||||
hashed := sha1.Sum(block.Bytes)
|
||||
|
||||
return &Config{
|
||||
clientID: params.ClientID,
|
||||
tokenURL: params.TokenURL,
|
||||
scopes: params.Scopes,
|
||||
cfg: params.Config,
|
||||
clientKey: rsaKey,
|
||||
x5t: base64.StdEncoding.EncodeToString(hashed[:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeClientKey decodes a PEM encoded rsa secret.
|
||||
func decodeClientKey(pemEncoded []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemEncoded)
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (ja *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
return ja.cfg.AuthCodeURL(state, opts...)
|
||||
}
|
||||
|
||||
// Exchange includes the client_assertion signed JWT.
|
||||
func (ja *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
signed, err := ja.jwtToken()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
|
||||
}
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
|
||||
oauth2.SetAuthURLParam("client_assertion", signed),
|
||||
)
|
||||
return ja.cfg.Exchange(ctx, code, opts...)
|
||||
}
|
||||
|
||||
func (ja *Config) jwtToken() (string, error) {
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iss": ja.clientID,
|
||||
"sub": ja.clientID,
|
||||
"aud": ja.tokenURL,
|
||||
// 5-10 minutes is recommended in the Azure docs.
|
||||
// So we'll use 5 minutes.
|
||||
"exp": now.Add(time.Minute * 5).Unix(),
|
||||
"jti": uuid.New().String(),
|
||||
"nbf": now.Unix(),
|
||||
"iat": now.Unix(),
|
||||
})
|
||||
token.Header["x5t"] = ja.x5t
|
||||
|
||||
signed, err := token.SignedString(ja.clientKey)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("sign jwt assertion: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func (ja *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource {
|
||||
return oauth2.ReuseTokenSource(token, &jwtTokenSource{
|
||||
cfg: ja,
|
||||
ctx: ctx,
|
||||
refreshToken: token.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
type jwtTokenSource struct {
|
||||
cfg *Config
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
}
|
||||
|
||||
// Token must be safe for concurrent use by multiple go routines
|
||||
// Very similar to the RetrieveToken implementation by the oauth2 package.
|
||||
// https://github.com/golang/oauth2/blob/master/internal/token.go#L212
|
||||
// Oauth2 package keeps this code unexported or in an /internal package,
|
||||
// so we have to copy the implementation :(
|
||||
func (src *jwtTokenSource) Token() (*oauth2.Token, error) {
|
||||
if src.refreshToken == "" {
|
||||
return nil, xerrors.New("oauth2: token expired and refresh token is not set")
|
||||
}
|
||||
cli := http.DefaultClient
|
||||
if v, ok := src.ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
cli = v
|
||||
}
|
||||
|
||||
token, err := src.cfg.jwtToken()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed jwt assertion: %w", err)
|
||||
}
|
||||
|
||||
v := url.Values{
|
||||
"client_assertion": {token},
|
||||
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
|
||||
"client_id": {src.cfg.clientID},
|
||||
"grant_type": {"refresh_token"},
|
||||
"scope": {strings.Join(src.cfg.scopes, " ")},
|
||||
"refresh_token": {src.refreshToken},
|
||||
}
|
||||
// Using params based auth
|
||||
req, err := http.NewRequest("POST", src.cfg.tokenURL, strings.NewReader(v.Encode()))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("oauth2: make token refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req = req.WithContext(src.ctx)
|
||||
resp, err := cli.Do(req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("oauth2: cannot get token: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("oauth2: cannot fetch token reading response body: %w", err)
|
||||
}
|
||||
|
||||
var tokenRes struct {
|
||||
oauth2.Token
|
||||
// Extra fields returned by the refresh that are needed
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
||||
// error fields
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
ErrorCode string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
ErrorURI string `json:"error_uri"`
|
||||
}
|
||||
|
||||
unmarshalError := json.Unmarshal(body, &tokenRes)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
// Return a standard oauth2 error. Attempt to read some error fields. The error fields
|
||||
// can be encoded in a few places, so this does not catch all of them.
|
||||
return nil, &oauth2.RetrieveError{
|
||||
Response: resp,
|
||||
Body: body,
|
||||
// Best effort for error fields
|
||||
ErrorCode: tokenRes.ErrorCode,
|
||||
ErrorDescription: tokenRes.ErrorDescription,
|
||||
ErrorURI: tokenRes.ErrorURI,
|
||||
}
|
||||
}
|
||||
|
||||
if unmarshalError != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot unmarshal token: %w", err)
|
||||
}
|
||||
|
||||
newToken := &oauth2.Token{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
TokenType: tokenRes.TokenType,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
}
|
||||
|
||||
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||
newToken.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||
}
|
||||
|
||||
// ID token is a JWT token. We can decode it to get the expiry.
|
||||
// Not really sure what to do if the ExpiresIn and JWT expiry differ,
|
||||
// but this one is attached in the JWT and guaranteed to be right for local
|
||||
// validation. So use this one if found.
|
||||
if v := tokenRes.IDToken; v != "" {
|
||||
// decode returned id token to get expiry
|
||||
claimSet, err := jws.Decode(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: error decoding JWT token: %w", err)
|
||||
}
|
||||
newToken.Expiry = time.Unix(claimSet.Exp, 0)
|
||||
}
|
||||
|
||||
return newToken, nil
|
||||
}
|
296
coderd/oauthpki/okidcpki_test.go
Normal file
296
coderd/oauthpki/okidcpki_test.go
Normal file
@ -0,0 +1,296 @@
|
||||
package oauthpki_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/oauthpki"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
testClientKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tIdqSqu
|
||||
ASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaNcoN8
|
||||
wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMWUJpV
|
||||
g8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyFkD35
|
||||
xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64ZxsqfuM4LX
|
||||
BakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABAoIBAQCb+aH7x0IylSir
|
||||
r1Z06RDBI9bunOwBA9aqkwdRuCg4zGsVQXljNnABgACz7837JQPRIUW2MU553otX
|
||||
yyE+RzNnsjkLxSgbqvSFOe+FDOx7iB5jm/euf4NNmZ0lU3iggurgJ6iVsgVgrQUF
|
||||
AyXX+d2gawLUDYjBwxgozkSodH2sXYSX+SWfSOXHsFzSa3tLtUMbAIflM0rlRXf7
|
||||
Z57M8mMomZUvmmojH+TnBQljJlU8lhrvOaDD4DT8qAtVHE3VluDBQ9/3E8OIjz+E
|
||||
EqUgWLgrdq1rIMhJbHN90NwLwWs+2PcRfdB6hqKPktLne2KZFOgVKlxPKOYByBq1
|
||||
PX/vJ/HBAoGBAMFmJ6nYqyUVl26ajlXmnXBjQ+iBLHo9lcUu84+rpqRf90Bsm5bd
|
||||
jMmYr3Yo3yXNiit3rvZzBfPElo+IVa1HpPtgOaa2AU5B3QzxWCNT0FNRQqMG2LcA
|
||||
CvB10pOdJEABQxr7d4eFRg2/KbF1fr0r0vqMEelwa5ejTg6ROD3DtadpAoGBANA0
|
||||
4EClniCwvd1IECy2oTuTDosXgmRKwRAcwgE34YXy1Y/L4X/ghFeCHi3ybrep0uwL
|
||||
ptJNK+0sqvPu6UhC356GfMqfuzOKNMkXybnPUbHrz5KTkN+QQMfPc73Veel2gpD3
|
||||
xNataEmHtxcOx0X1OnjwyZZpmMbrUY3Cackn+durAoGBAKYR5nU+jJfnloVvSlIR
|
||||
GZhsZN++LEc7ouQTkSoJp6r2jQZRPLmrvT1PUzwPlK6NdNwmhaMy2iWc5fySgZ+u
|
||||
KcmBs3+oQi7E9+ApThnn2rfwy1vagTWDX+FkC1KeWYZsjwcYcGd61dDwGgk8b3xZ
|
||||
qW1j4e2mj31CycBQiw7eg5ohAoGADvkOe3etlHpBXS12hFCp7afYruYE6YN6uNbo
|
||||
mL/VBxX8h7fIwrJ5sfVYiENb9PdQhMsdtxf3pbnFnX875Ydxn2vag5PTGZTB0QhV
|
||||
6HfhTyM/LTJRg9JS5kuj7i3w83ojT5uR20JjMo6A+zaD3CMTjmj6hkeXxg5cMg6e
|
||||
HuoyDLsCgYBcbboYMFT1cUSxBeMtPGt3CxxZUYnUQaRUeOcjqYYlFL+DCWhY7pxH
|
||||
EnLhwW/KzkDzOmwRmmNOMqD7UhR/ayxR+avRt6v5d5l8fVCuNexgs7kR9L5IQp9l
|
||||
YV2wsCoXBCcuPmio/te44U//BlzprEu0w1iHpb3ibmQg4y291R0TvQ==
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
testClientCert = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEOjCCAiKgAwIBAgIQMO50KnWsRbmrrthPQgyubjANBgkqhkiG9w0BAQsFADAY
|
||||
MRYwFAYDVQQDEw1Mb2NhbGhvc3RDZXJ0MB4XDTIzMDgxMDE2MjYxOFoXDTI1MDIx
|
||||
MDE2MjU0M1owFDESMBAGA1UEAxMJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
|
||||
AAOCAQ8AMIIBCgKCAQEAnUryZEfn5kA8wuk9a7ogFuWbk3uPHEhioYuAg9m3/tId
|
||||
qSquASpRzw8+1nORTf3ykWRRlhxZWnKimmkB0Ux5Yrz9TDVWDQbzEH3B8ibMlmaN
|
||||
coN8wYVzeEpqCe3fJagnV0lh0sHB1Z+vhcJ/M2nEAdyfhIgQEbG6Xtl2+WcGqyMW
|
||||
UJpVg8+ebK+JkXELAGN1hg3DdV52gjodEjoe1/ibHz8y3NR7j2tOKix7iKOhccyF
|
||||
kD35xqSnfyZJK5yxIfmGiWdVOIGqc2rYpgvrXJLTOjLoeyDSNi+Q604T64Zxsqfu
|
||||
M4LXBakVG3EwHFXPBfsBKjUE9HYvXEXw3fJP9K6mIwIDAQABo4GDMIGAMA4GA1Ud
|
||||
DwEB/wQEAwIDuDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwHQYDVR0O
|
||||
BBYEFAYCdgydG3h2SNWF+BfAyJtNliJtMB8GA1UdIwQYMBaAFHR/aptP0RUNNFyf
|
||||
5uky527SECt1MA8GA1UdEQQIMAaHBH8AAAEwDQYJKoZIhvcNAQELBQADggIBAI6P
|
||||
ymG7l06JvJ3p6xgaMyOxgkpQl6WkY4LJHVEhfeDSoO3qsJc4PxUdSExJsT84weXb
|
||||
lF+tK6D/CPlvjmG720IlB5cSKJ71rWjwmaMWKxWKXyoZdDrHAp55+FNdXegUZF2o
|
||||
EF/ZM5CHaO8iHMkuWEv1OASHBQWC/o4spUN5HGQ9HepwLVvO/aX++LYfvfL9faKA
|
||||
IT+w9i8pJbfItFmfA8x2OEVZk8aEA0WtKdfsMwzGmZ1GSGa4UYcynxQGCMiB5h4L
|
||||
C/dpoJRbEzdGLuTZgV2SCaN3k5BrH4aaILI9tqZaq0gamN9Rd2yji3cGiduCeAAo
|
||||
RmVcl9fBliMLxylWEP5+B2JmCZEc8Lfm0TBNnjaG17KY40gzbfBYixBxBTYgsPua
|
||||
bfprtfksSG++zcsDbkC8CtPamtlNWtDAiFp4yQRkP79PlJO6qCdTrFWPukTMCMso
|
||||
25hjLvxj1fLy/jSMDEZu/oQ14TMCZSGHRjz4CPiaCfXqgqOtVOD+5+yWInwUGp/i
|
||||
Nb1vIq4ruEAbyCbdWKHbE0yT5AP7hm5ZNybpZ4/311AEBD2HKip/OqB05p99XcLw
|
||||
BIC4ODNvwCn6x00KZoqWz/MX2dEQ/HqWiWaDB/OSemfTVE3I94mzEWnqpF2cQpcT
|
||||
B1B7CpkMU55hPP+7nsofCszNrMDXT8Z5w2a3zLKM
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
)
|
||||
|
||||
// TestAzureADPKIOIDC ensures we do not break Azure AD compatibility.
|
||||
// It runs an oauth2.Exchange method and hijacks the request to only check the
|
||||
// request side of the transaction.
|
||||
func TestAzureADPKIOIDC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oauthCfg := &oauth2.Config{
|
||||
ClientID: "random-client-id",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: "https://login.microsoftonline.com/6a1e9139-13f2-4afb-8f46-036feac8bd79/v2.0/token",
|
||||
},
|
||||
}
|
||||
|
||||
pkiConfig, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
|
||||
ClientID: oauthCfg.ClientID,
|
||||
TokenURL: oauthCfg.Endpoint.TokenURL,
|
||||
PemEncodedKey: []byte(testClientKey),
|
||||
PemEncodedCert: []byte(testClientCert),
|
||||
Config: oauthCfg,
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
})
|
||||
require.NoError(t, err, "failed to create pki config")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
ctx = oidc.ClientContext(ctx, &http.Client{
|
||||
Transport: &fakeRoundTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
resp := &http.Response{
|
||||
Status: "500 Internal Service Error",
|
||||
}
|
||||
// This is the easiest way to hijack the request and check
|
||||
// the params. The oauth2 package uses unexported types and
|
||||
// options, so we need to view the actual request created.
|
||||
assertJWTAuth(t, req)
|
||||
return resp, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
_, err = pkiConfig.Exchange(ctx, base64.StdEncoding.EncodeToString([]byte("random-code")))
|
||||
// We hijack the request and return an error intentionally
|
||||
require.Error(t, err, "error expected")
|
||||
}
|
||||
|
||||
// TestSavedAzureADPKIOIDC was created by capturing actual responses from an Azure
|
||||
// AD instance and saving them to replay, removing some details.
|
||||
// The reason this is done is that this is the only way to assert values
|
||||
// passed to the oauth2 provider via http requests.
|
||||
// It is not feasible to run against an actual Azure AD instance, so this attempts
|
||||
// to prevent some regressions by running a full "e2e" oauth and asserting some
|
||||
// of the request values.
|
||||
func TestSavedAzureADPKIOIDC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
stateString = "random-state"
|
||||
oauth2Code = base64.StdEncoding.EncodeToString([]byte("random-code"))
|
||||
)
|
||||
|
||||
// Real oauth config. We will hijack all http requests so some of these values
|
||||
// are fake.
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: "fake_app",
|
||||
ClientSecret: "",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/fake_app/oauth2/v2.0/token",
|
||||
AuthStyle: 0,
|
||||
},
|
||||
RedirectURL: "http://localhost/api/v2/users/oidc/callback",
|
||||
Scopes: []string{"openid", "profile", "email", "offline_access"},
|
||||
}
|
||||
|
||||
initialExchange := false
|
||||
tokenRefreshed := false
|
||||
|
||||
// Create the oauthpki config
|
||||
pki, err := oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
|
||||
ClientID: cfg.ClientID,
|
||||
TokenURL: cfg.Endpoint.TokenURL,
|
||||
Scopes: []string{"openid", "email", "profile", "offline_access"},
|
||||
PemEncodedKey: []byte(testClientKey),
|
||||
PemEncodedCert: []byte(testClientCert),
|
||||
Config: cfg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var fakeCtx context.Context
|
||||
fakeClient := &http.Client{
|
||||
Transport: fakeRoundTripper{
|
||||
roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
if strings.Contains(req.URL.String(), "authorize") {
|
||||
// This is the user hitting the browser endpoint to begin the OIDC
|
||||
// auth flow.
|
||||
|
||||
// Authorize should redirect the user back to the app after authentication on
|
||||
// the IDP.
|
||||
resp := httptest.NewRecorder()
|
||||
v := url.Values{
|
||||
"code": {oauth2Code},
|
||||
"state": {stateString},
|
||||
"session_state": {"a18cf797-1e2b-4bc3-baf9-66b41a4997cf"},
|
||||
}
|
||||
|
||||
// This url doesn't really matter since the fake client will hiject this actual request.
|
||||
http.Redirect(resp, req, "http://localhost:3000/api/v2/users/oidc/callback?"+v.Encode(), http.StatusTemporaryRedirect)
|
||||
return resp.Result(), nil
|
||||
}
|
||||
if strings.Contains(req.URL.String(), "v2.0/token") {
|
||||
vals := assertJWTAuth(t, req)
|
||||
switch vals.Get("grant_type") {
|
||||
case "authorization_code":
|
||||
// Initial token
|
||||
initialExchange = true
|
||||
assert.Equal(t, oauth2Code, vals.Get("code"), "initial exchange code mismatch")
|
||||
case "refresh_token":
|
||||
// refreshed token
|
||||
tokenRefreshed = true
|
||||
assert.Equal(t, "<refresh_token_JWT>", vals.Get("refresh_token"), "refresh token required")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
// Taken from an actual response
|
||||
// Just always return a token no matter what.
|
||||
resp.Header().Set("Content-Type", "application/json")
|
||||
_, _ = resp.Write([]byte(`{
|
||||
"token_type":"Bearer",
|
||||
"scope":"email openid profile AccessReview.ReadWrite.Membership Group.Read.All Group.ReadWrite.All User.Read",
|
||||
"expires_in":4009,
|
||||
"ext_expires_in":4009,
|
||||
"access_token":"<access_token_JWT>",
|
||||
"refresh_token":"<refresh_token_JWT>",
|
||||
"id_token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ii1LSTNROW5OUjdiUm9meG1lWm9YcWJIWkdldyJ9.eyJhdWQiOiIxZjAxODMyYS1mZWViLTQyZGMtODFkOS01ZjBhYjZhMDQxZTAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vMTEwZjBjMGYtY2Q3Ni00NzE3LWE2ZjgtNGVlYTNkMGY4MTA5L3YyLjAiLCJpYXQiOjE2OTE3OTI2MzQsIm5iZiI6MTY5MTc5MjYzNCwiZXhwIjoxNjkxNzk2NTM0LCJhaW8iOiJBWVFBZS84VUFBQUE1eEtqMmVTdWFXVmZsRlhCeGJJTnMvSkVyVHFvUGlaQW5ENmJIZWF3a2RRcisyRVRwM3RGNGY3akxicnh3ODhhVm9QOThrY0xMNjhON1hVV3FCN1I1N2JQRU9EclRlSUI1S0lyUHBjbCtIeXR0a1ljOVdWQklVVEErSllQbzl1a0ZjbGNWZ1krWUc3eHlmdi90K3Q1ZEczblNuZEdEQ1FYRVIxbDlTNko1T2c9IiwiZW1haWwiOiJzdGV2ZW5AY29kZXIuY29tIiwiZ3JvdXBzIjpbImM4MDQ4ZTkxLWY1YzMtNDdlNS05NjkzLTgzNGRlODQwMzRhZCIsIjcwYjQ4MTc1LTEwN2ItNGFkOC1iNDA1LTRkODg4YTFjNDY2ZiJdLCJpZHAiOiJtYWlsIiwibmFtZSI6IlN0ZXZlbiBNIiwib2lkIjoiN2JhNDYzNjAtZTAyNy00OTVhLTlhZTUtM2FlYWZlMzY3MGEyIiwicHJlZmVycmVkX3VzZXJuYW1lIjoic3RldmVuQGNvZGVyLmNvbSIsInByb3ZfZGF0YSI6W3siQXQiOnRydWUsIlByb3YiOiJnaXRodWIuY29tIiwiQWx0c2VjaWQiOiI1NDQ2Mjk4IiwiQWNjZXNzVG9rZW4iOm51bGx9XSwicmgiOiIwLkFUZ0FEd3dQRVhiTkYwZW0tRTdxUFEtQkNTcURBUl9yX3R4Q2dkbGZDcmFnUWVBNEFPRS4iLCJyb2xlcyI6WyJUZW1wbGF0ZUF1dGhvcnMiXSwic3ViIjoib0JTN3FjUERKdWlDMEYyQ19XdDJycVlvanhpT0o3S3JFWjlkQ1RkTGVYNCIsInRpZCI6IjExMGYwYzBmLWNkNzYtNDcxNy1hNmY4LTRlZWEzZDBmODEwOSIsInV0aSI6IktReGlIWGtaZUVxcC1tQWlVdTlyQUEiLCJ2ZXIiOiIyLjAiLCJyb2xlczIiOiJUZW1wbGF0ZUF1dGhvcnMifQ.JevFI4Xm9dW7kQq4xEgZnUaU0SqbeOAFtT0YIKQNefR9Db4sjxCaKRmX0pPt-CM9j45d6fAiAkLFDAqjlSbi4Zi0GbEomT3yegmuxKgEgjPpJlGjF2TBUpsNNyn5gJ9Wkct9BfwALJhX2ePJFzIlkvx9opNNbNK1qHKMMjOSRFG6AGExKRDiQAME0a4hVgCwrAdUs4JrCcj4LqB84dODN-eoh-jx2-1wDvf6fovfwLHDQwjY4lfBxaYdNavKM369hrhU-U067rSnCzvDD26f4VLhPF52hiQIbTVN5t7p_1XmcduUiaNnmr9AZiZxZ-94mctSRRR8xG0pNwO2yv84iA"
|
||||
}`))
|
||||
return resp.Result(), nil
|
||||
}
|
||||
// This is the "Coder" half of things. We can keep this in the fake
|
||||
// client, essentially being the fake client on both sides of the OIDC
|
||||
// flow.
|
||||
if strings.Contains(req.URL.String(), "v2/users/oidc/callback") {
|
||||
// This is the callback from the IDP.
|
||||
code := req.URL.Query().Get("code")
|
||||
require.Equal(t, oauth2Code, code, "code mismatch")
|
||||
state := req.URL.Query().Get("state")
|
||||
require.Equal(t, stateString, state, "state mismatch")
|
||||
|
||||
// Exchange for token should work
|
||||
token, err := pki.Exchange(fakeCtx, code)
|
||||
if !assert.NoError(t, err) {
|
||||
return httptest.NewRecorder().Result(), nil
|
||||
}
|
||||
|
||||
// Also try a refresh
|
||||
cpy := token
|
||||
cpy.Expiry = time.Now().Add(time.Minute * -1)
|
||||
src := pki.TokenSource(fakeCtx, cpy)
|
||||
_, err = src.Token()
|
||||
tokenRefreshed = true
|
||||
assert.NoError(t, err, "token refreshed")
|
||||
return httptest.NewRecorder().Result(), nil
|
||||
}
|
||||
|
||||
return nil, xerrors.Errorf("not implemented")
|
||||
},
|
||||
},
|
||||
}
|
||||
fakeCtx = oidc.ClientContext(context.Background(), fakeClient)
|
||||
_ = fakeCtx
|
||||
|
||||
// This simulates a client logging into the browser. The 307 redirect will
|
||||
// make sure this goes through the full flow.
|
||||
// nolint: noctx
|
||||
resp, err := fakeClient.Get(pki.AuthCodeURL("state", oauth2.AccessTypeOffline))
|
||||
require.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
require.True(t, initialExchange, "initial token exchange complete")
|
||||
require.True(t, tokenRefreshed, "token was refreshed")
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
roundTrip func(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (f fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f.roundTrip(req)
|
||||
}
|
||||
|
||||
// assertJWTAuth will assert the basic JWT auth assertions. It will return the
|
||||
// url.Values from the request body for any additional assertions to be made.
|
||||
func assertJWTAuth(t *testing.T, r *http.Request) url.Values {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if !assert.NoError(t, err) {
|
||||
return nil
|
||||
}
|
||||
vals, err := url.ParseQuery(string(body))
|
||||
if !assert.NoError(t, err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", vals.Get("client_assertion_type"))
|
||||
jwtToken := vals.Get("client_assertion")
|
||||
// No need to actually verify the jwt is signed right.
|
||||
parsedToken, _, err := (&jwt.Parser{}).ParseUnverified(jwtToken, jwt.MapClaims{})
|
||||
if !assert.NoError(t, err, "failed to parse jwt token") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Azure requirements
|
||||
assert.NotEmpty(t, parsedToken.Header["x5t"], "hashed cert missing")
|
||||
assert.Equal(t, "RS256", parsedToken.Header["alg"], "azure only accepts RS256")
|
||||
assert.Equal(t, "JWT", parsedToken.Header["typ"], "azure only accepts JWT")
|
||||
|
||||
return vals
|
||||
}
|
Reference in New Issue
Block a user