mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore: implement device auth flow for fake idp (#11707)
* chore: implement device auth flow for fake idp
This commit is contained in:
@ -10,11 +10,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -34,9 +37,11 @@ import (
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/coderd/util/syncmap"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type token struct {
|
||||
@ -45,6 +50,13 @@ type token struct {
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
type deviceFlow struct {
|
||||
// userInput is the expected input to authenticate the device flow.
|
||||
userInput string
|
||||
exp time.Time
|
||||
granted bool
|
||||
}
|
||||
|
||||
// FakeIDP is a functional OIDC provider.
|
||||
// It only supports 1 OIDC client.
|
||||
type FakeIDP struct {
|
||||
@ -77,6 +89,8 @@ type FakeIDP struct {
|
||||
refreshTokens *syncmap.Map[string, string]
|
||||
stateToIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
refreshIDTokenClaims *syncmap.Map[string, jwt.MapClaims]
|
||||
// Device flow
|
||||
deviceCode *syncmap.Map[string, deviceFlow]
|
||||
|
||||
// hooks
|
||||
// hookValidRedirectURL can be used to reject a redirect url from the
|
||||
@ -226,6 +240,8 @@ const (
|
||||
authorizePath = "/oauth2/authorize"
|
||||
keysPath = "/oauth2/keys"
|
||||
userInfoPath = "/oauth2/userinfo"
|
||||
deviceAuth = "/login/device/code"
|
||||
deviceVerify = "/login/device"
|
||||
)
|
||||
|
||||
func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
@ -246,6 +262,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
|
||||
refreshTokensUsed: syncmap.New[string, bool](),
|
||||
stateToIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
refreshIDTokenClaims: syncmap.New[string, jwt.MapClaims](),
|
||||
deviceCode: syncmap.New[string, deviceFlow](),
|
||||
hookOnRefresh: func(_ string) error { return nil },
|
||||
hookUserInfo: func(email string) (jwt.MapClaims, error) { return jwt.MapClaims{}, nil },
|
||||
hookValidRedirectURL: func(redirectURL string) error { return nil },
|
||||
@ -288,11 +305,12 @@ func (f *FakeIDP) updateIssuerURL(t testing.TB, issuer string) {
|
||||
// ProviderJSON is the JSON representation of the OpenID Connect provider
|
||||
// These are all the urls that the IDP will respond to.
|
||||
f.provider = ProviderJSON{
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||
Issuer: issuer,
|
||||
AuthURL: u.ResolveReference(&url.URL{Path: authorizePath}).String(),
|
||||
TokenURL: u.ResolveReference(&url.URL{Path: tokenPath}).String(),
|
||||
JWKSURL: u.ResolveReference(&url.URL{Path: keysPath}).String(),
|
||||
UserInfoURL: u.ResolveReference(&url.URL{Path: userInfoPath}).String(),
|
||||
DeviceCodeURL: u.ResolveReference(&url.URL{Path: deviceAuth}).String(),
|
||||
Algorithms: []string{
|
||||
"RS256",
|
||||
},
|
||||
@ -467,6 +485,31 @@ func (f *FakeIDP) ExternalLogin(t testing.TB, client *codersdk.Client, opts ...f
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
|
||||
// DeviceLogin does the oauth2 device flow for external auth providers.
|
||||
func (*FakeIDP) DeviceLogin(t testing.TB, client *codersdk.Client, externalAuthID string) {
|
||||
// First we need to initiate the device flow. This will have Coder hit the
|
||||
// fake IDP and get a device code.
|
||||
device, err := client.ExternalAuthDeviceByID(context.Background(), externalAuthID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the user needs to go to the fake IDP page and click "allow" and enter
|
||||
// the device code input. For our purposes, we just send an http request to
|
||||
// the verification url. No additional user input is needed.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
resp, err := client.Request(ctx, http.MethodPost, device.VerificationURI, nil)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Now we need to exchange the device code for an access token. We do this
|
||||
// in this method because it is the user that does the polling for the device
|
||||
// auth flow, not the backend.
|
||||
err = client.ExternalAuthDeviceExchange(context.Background(), externalAuthID, codersdk.ExternalAuthDeviceExchange{
|
||||
DeviceCode: device.DeviceCode,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// CreateAuthCode emulates a user clicking "allow" on the IDP page. When doing
|
||||
// unit tests, it's easier to skip this step sometimes. It does make an actual
|
||||
// request to the IDP, so it should be equivalent to doing this "manually" with
|
||||
@ -536,12 +579,13 @@ func (f *FakeIDP) OIDCCallback(t testing.TB, state string, idTokenClaims jwt.Map
|
||||
|
||||
// ProviderJSON is the .well-known/configuration JSON
|
||||
type ProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
DeviceCodeURL string `json:"device_authorization_endpoint"`
|
||||
Algorithms []string `json:"id_token_signing_alg_values_supported"`
|
||||
// This is custom
|
||||
ExternalAuthURL string `json:"external_auth_url"`
|
||||
}
|
||||
@ -709,8 +753,15 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
}))
|
||||
|
||||
mux.Handle(tokenPath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
values, err := f.authenticateOIDCClientRequest(t, r)
|
||||
var values url.Values
|
||||
var err error
|
||||
if r.URL.Query().Get("grant_type") == "urn:ietf:params:oauth:grant-type:device_code" {
|
||||
values = r.URL.Query()
|
||||
} else {
|
||||
values, err = f.authenticateOIDCClientRequest(t, r)
|
||||
}
|
||||
f.logger.Info(r.Context(), "http idp call token",
|
||||
slog.F("url", r.URL.String()),
|
||||
slog.F("valid", err == nil),
|
||||
slog.F("grant_type", values.Get("grant_type")),
|
||||
slog.F("values", values.Encode()),
|
||||
@ -784,6 +835,37 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
f.refreshTokensUsed.Store(refreshToken, true)
|
||||
// Always invalidate the refresh token after it is used.
|
||||
f.refreshTokens.Delete(refreshToken)
|
||||
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||
// Device flow
|
||||
var resp externalauth.ExchangeDeviceCodeResponse
|
||||
deviceCode := values.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
resp.Error = "invalid_request"
|
||||
resp.ErrorDescription = "missing device_code"
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
|
||||
return
|
||||
}
|
||||
|
||||
deviceFlow, ok := f.deviceCode.Load(deviceCode)
|
||||
if !ok {
|
||||
resp.Error = "invalid_request"
|
||||
resp.ErrorDescription = "device_code provided not found"
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, resp)
|
||||
return
|
||||
}
|
||||
|
||||
if !deviceFlow.granted {
|
||||
// Status code ok with the error as pending.
|
||||
resp.Error = "authorization_pending"
|
||||
resp.ErrorDescription = ""
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
// Would be nice to get an actual email here.
|
||||
claims = jwt.MapClaims{
|
||||
"email": "unknown-dev-auth",
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected grant_type %q", values.Get("grant_type"))
|
||||
http.Error(rw, "invalid grant_type", http.StatusBadRequest)
|
||||
@ -807,8 +889,30 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
// Store the claims for the next refresh
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(rw).Encode(token)
|
||||
mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept"))
|
||||
if mediaType == "application/x-www-form-urlencoded" {
|
||||
// This val encode might not work for some data structures.
|
||||
// It's good enough for now...
|
||||
rw.Header().Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
vals := url.Values{}
|
||||
for k, v := range token {
|
||||
vals.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
_, _ = rw.Write([]byte(vals.Encode()))
|
||||
return
|
||||
}
|
||||
// Default to json since the oauth2 package doesn't use Accept headers.
|
||||
if mediaType == "application/json" || mediaType == "" {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(rw).Encode(token)
|
||||
return
|
||||
}
|
||||
|
||||
// If we get something we don't support, throw an error.
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "'Accept' header contains unsupported media type",
|
||||
Detail: fmt.Sprintf("Found %q", mediaType),
|
||||
})
|
||||
}))
|
||||
|
||||
validateMW := func(rw http.ResponseWriter, r *http.Request) (email string, ok bool) {
|
||||
@ -886,6 +990,125 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
_ = json.NewEncoder(rw).Encode(set)
|
||||
}))
|
||||
|
||||
mux.Handle(deviceVerify, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call device verify")
|
||||
|
||||
inputParam := "user_input"
|
||||
userInput := r.URL.Query().Get(inputParam)
|
||||
if userInput == "" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid user input",
|
||||
Detail: fmt.Sprintf("Hit this url again with ?%s=<user_code>", inputParam),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.URL.Query().Get("device_code")
|
||||
if deviceCode == "" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid device code",
|
||||
Detail: "Hit this url again with ?device_code=<device_code>",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
flow, ok := f.deviceCode.Load(deviceCode)
|
||||
if !ok {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid device code",
|
||||
Detail: "Device code not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if time.Now().After(flow.exp) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid device code",
|
||||
Detail: "Device code expired.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(flow.userInput) != strings.TrimSpace(userInput) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid device code",
|
||||
Detail: "user code does not match",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
f.deviceCode.Store(deviceCode, deviceFlow{
|
||||
userInput: flow.userInput,
|
||||
exp: flow.exp,
|
||||
granted: true,
|
||||
})
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Device authenticated!",
|
||||
})
|
||||
}))
|
||||
|
||||
mux.Handle(deviceAuth, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Info(r.Context(), "http call device auth")
|
||||
|
||||
p := httpapi.NewQueryParamParser()
|
||||
p.Required("client_id")
|
||||
clientID := p.String(r.URL.Query(), "", "client_id")
|
||||
_ = p.String(r.URL.Query(), "", "scopes")
|
||||
if len(p.Errors) > 0 {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid query params",
|
||||
Validations: p.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if clientID != f.clientID {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid client id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := uuid.NewString()
|
||||
lifetime := time.Second * 900
|
||||
flow := deviceFlow{
|
||||
//nolint:gosec
|
||||
userInput: fmt.Sprintf("%d", rand.Intn(9999999)+1e8),
|
||||
}
|
||||
f.deviceCode.Store(deviceCode, deviceFlow{
|
||||
userInput: flow.userInput,
|
||||
exp: time.Now().Add(lifetime),
|
||||
})
|
||||
|
||||
verifyURL := f.issuerURL.ResolveReference(&url.URL{
|
||||
Path: deviceVerify,
|
||||
RawQuery: url.Values{
|
||||
"device_code": {deviceCode},
|
||||
"user_input": {flow.userInput},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
|
||||
if mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Accept")); mediaType == "application/json" {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, map[string]any{
|
||||
"device_code": deviceCode,
|
||||
"user_code": flow.userInput,
|
||||
"verification_uri": verifyURL,
|
||||
"expires_in": int(lifetime.Seconds()),
|
||||
"interval": 3,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// By default, GitHub form encodes these.
|
||||
_, _ = fmt.Fprint(rw, url.Values{
|
||||
"device_code": {deviceCode},
|
||||
"user_code": {flow.userInput},
|
||||
"verification_uri": {verifyURL},
|
||||
"expires_in": {strconv.Itoa(int(lifetime.Seconds()))},
|
||||
"interval": {"3"},
|
||||
}.Encode())
|
||||
}))
|
||||
|
||||
mux.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||
f.logger.Error(r.Context(), "http call not found", slog.F("path", r.URL.Path))
|
||||
t.Errorf("unexpected request to IDP at path %q. Not supported", r.URL.Path)
|
||||
@ -987,6 +1210,8 @@ type ExternalAuthConfigOptions struct {
|
||||
// completely customize the response. It captures all routes under the /external-auth-validate/*
|
||||
// so the caller can do whatever they want and even add routes.
|
||||
routes map[string]func(email string, rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
UseDeviceAuth bool
|
||||
}
|
||||
|
||||
func (o *ExternalAuthConfigOptions) AddRoute(route string, handle func(email string, rw http.ResponseWriter, r *http.Request)) *ExternalAuthConfigOptions {
|
||||
@ -1033,9 +1258,10 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||
}
|
||||
}
|
||||
instrumentF := promoauth.NewFactory(prometheus.NewRegistry())
|
||||
oauthCfg := instrumentF.New(f.clientID, f.OIDCConfig(t, nil))
|
||||
cfg := &externalauth.Config{
|
||||
DisplayName: id,
|
||||
InstrumentedOAuth2Config: instrumentF.New(f.clientID, f.OIDCConfig(t, nil)),
|
||||
InstrumentedOAuth2Config: oauthCfg,
|
||||
ID: id,
|
||||
// No defaults for these fields by omitting the type
|
||||
Type: "",
|
||||
@ -1043,7 +1269,19 @@ func (f *FakeIDP) ExternalAuthConfig(t testing.TB, id string, custom *ExternalAu
|
||||
// Omit the /user for the validate so we can easily append to it when modifying
|
||||
// the cfg for advanced tests.
|
||||
ValidateURL: f.issuerURL.ResolveReference(&url.URL{Path: "/external-auth-validate/"}).String(),
|
||||
DeviceAuth: &externalauth.DeviceAuth{
|
||||
Config: oauthCfg,
|
||||
ClientID: f.clientID,
|
||||
TokenURL: f.provider.TokenURL,
|
||||
Scopes: []string{},
|
||||
CodeURL: f.provider.DeviceCodeURL,
|
||||
},
|
||||
}
|
||||
|
||||
if !custom.UseDeviceAuth {
|
||||
cfg.DeviceAuth = nil
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
Reference in New Issue
Block a user