feat: allow storing extra oauth token properties in the database (#10152)

This commit is contained in:
Kyle Carberry
2023-10-09 18:49:30 -05:00
committed by GitHub
parent 35538e1051
commit 863c2e7b64
25 changed files with 223 additions and 60 deletions

6
coderd/apidoc/docs.go generated
View File

@ -8381,6 +8381,12 @@ const docTemplate = `{
"description": "DisplayName is shown in the UI to identify the auth config.",
"type": "string"
},
"extra_token_keys": {
"type": "array",
"items": {
"type": "string"
}
},
"id": {
"description": "ID is a unique identifier for the auth config.\nIt defaults to ` + "`" + `type` + "`" + ` when not provided.",
"type": "string"

View File

@ -7515,6 +7515,12 @@
"description": "DisplayName is shown in the UI to identify the auth config.",
"type": "string"
},
"extra_token_keys": {
"type": "array",
"items": {
"type": "string"
}
},
"id": {
"description": "ID is a unique identifier for the auth config.\nIt defaults to `type` when not provided.",
"type": "string"

View File

@ -68,6 +68,7 @@ type FakeIDP struct {
// "Authorized Redirect URLs". This can be used to emulate that.
hookValidRedirectURL func(redirectURL string) error
hookUserInfo func(email string) (jwt.MapClaims, error)
hookMutateToken func(token map[string]interface{})
fakeCoderd func(req *http.Request) (*http.Response, error)
hookOnRefresh func(email string) error
// Custom authentication for the client. This is useful if you want
@ -112,6 +113,14 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
}
}
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
// These extra fields can override the default fields (id_token, access_token, etc).
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookMutateToken = mutateToken
}
}
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
return func(f *FakeIDP) {
f.hookAuthenticateClient = hook
@ -621,6 +630,9 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
"expires_in": int64((time.Minute * 5).Seconds()),
"id_token": f.encodeClaims(t, claims),
}
if f.hookMutateToken != nil {
f.hookMutateToken(token)
}
// Store the claims for the next refresh
f.refreshIDTokenClaims.Store(refreshToken, claims)

View File

@ -4246,6 +4246,7 @@ func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.Ins
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
OAuthExpiry: arg.OAuthExpiry,
OAuthExtra: arg.OAuthExtra,
}
q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink)
return gitAuthLink, nil
@ -5301,6 +5302,7 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
gitAuthLink.OAuthExtra = arg.OAuthExtra
q.externalAuthLinks[index] = gitAuthLink
return gitAuthLink, nil

View File

@ -514,6 +514,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
}
func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAuthLink) database.ExternalAuthLink {
msg := takeFirst(&orig.OAuthExtra, &pqtype.NullRawMessage{})
link, err := db.InsertExternalAuthLink(genCtx, database.InsertExternalAuthLinkParams{
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
UserID: takeFirst(orig.UserID, uuid.New()),
@ -524,6 +525,7 @@ func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAut
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
OAuthExtra: *msg,
})
require.NoError(t, err, "insert external auth link")

View File

@ -359,7 +359,8 @@ CREATE TABLE external_auth_links (
oauth_refresh_token text NOT NULL,
oauth_expiry timestamp with time zone NOT NULL,
oauth_access_token_key_id text,
oauth_refresh_token_key_id text
oauth_refresh_token_key_id text,
oauth_extra jsonb
);
COMMENT ON COLUMN external_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';

View File

@ -0,0 +1 @@
ALTER TABLE external_auth_links DROP COLUMN "oauth_extra";

View File

@ -0,0 +1 @@
ALTER TABLE external_auth_links ADD COLUMN "oauth_extra" jsonb;

View File

@ -1680,7 +1680,8 @@ type ExternalAuthLink struct {
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
type File struct {

View File

@ -751,7 +751,7 @@ func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest strin
}
const getExternalAuthLink = `-- name: GetExternalAuthLink :one
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
`
type GetExternalAuthLinkParams struct {
@ -772,12 +772,13 @@ func (q *sqlQuerier) GetExternalAuthLink(ctx context.Context, arg GetExternalAut
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}
const getExternalAuthLinksByUserID = `-- name: GetExternalAuthLinksByUserID :many
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE user_id = $1
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE user_id = $1
`
func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) {
@ -799,6 +800,7 @@ func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uu
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
); err != nil {
return nil, err
}
@ -823,7 +825,8 @@ INSERT INTO external_auth_links (
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
oauth_expiry,
oauth_extra
) VALUES (
$1,
$2,
@ -833,20 +836,22 @@ INSERT INTO external_auth_links (
$6,
$7,
$8,
$9
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
$9,
$10
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
`
type InsertExternalAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExternalAuthLinkParams) (ExternalAuthLink, error) {
@ -860,6 +865,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
arg.OAuthExtra,
)
var i ExternalAuthLink
err := row.Scan(
@ -872,6 +878,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}
@ -883,19 +890,21 @@ UPDATE external_auth_links SET
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
oauth_expiry = $8,
oauth_extra = $9
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
`
type UpdateExternalAuthLinkParams struct {
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
ProviderID string `db:"provider_id" json:"provider_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
}
func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) {
@ -908,6 +917,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
arg.OAuthRefreshToken,
arg.OAuthRefreshTokenKeyID,
arg.OAuthExpiry,
arg.OAuthExtra,
)
var i ExternalAuthLink
err := row.Scan(
@ -920,6 +930,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
&i.OAuthExpiry,
&i.OAuthAccessTokenKeyID,
&i.OAuthRefreshTokenKeyID,
&i.OAuthExtra,
)
return i, err
}

View File

@ -14,7 +14,8 @@ INSERT INTO external_auth_links (
oauth_access_token_key_id,
oauth_refresh_token,
oauth_refresh_token_key_id,
oauth_expiry
oauth_expiry,
oauth_extra
) VALUES (
$1,
$2,
@ -24,7 +25,8 @@ INSERT INTO external_auth_links (
$6,
$7,
$8,
$9
$9,
$10
) RETURNING *;
-- name: UpdateExternalAuthLink :one
@ -34,5 +36,6 @@ UPDATE external_auth_links SET
oauth_access_token_key_id = $5,
oauth_refresh_token = $6,
oauth_refresh_token_key_id = $7,
oauth_expiry = $8
oauth_expiry = $8,
oauth_extra = $9
WHERE provider_id = $1 AND user_id = $2 RETURNING *;

View File

@ -53,6 +53,7 @@ overrides:
oauth_id_token: OAuthIDToken
oauth_refresh_token: OAuthRefreshToken
oauth_refresh_token_key_id: OAuthRefreshTokenKeyID
oauth_extra: OAuthExtra
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus
gitsshkey: GitSSHKey

View File

@ -14,6 +14,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/sqlc-dev/pqtype"
)
// @Summary Get external auth by ID
@ -132,6 +133,8 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: token.Expiry,
// No extra data from device auth!
OAuthExtra: pqtype.NullRawMessage{},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -150,6 +153,7 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: pqtype.NullRawMessage{},
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -201,7 +205,15 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
apiKey = httpmw.APIKey(r)
)
_, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
extra, err := externalAuthConfig.GenerateTokenExtra(state.Token)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to generate token extra.",
Detail: err.Error(),
})
return
}
_, err = api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: externalAuthConfig.ID,
UserID: apiKey.UserID,
})
@ -224,6 +236,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
OAuthExpiry: state.Token.Expiry,
OAuthExtra: extra,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -242,6 +255,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
OAuthRefreshToken: state.Token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: state.Token.Expiry,
OAuthExtra: extra,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{

View File

@ -15,6 +15,7 @@ import (
"golang.org/x/xerrors"
"github.com/google/go-github/v43/github"
"github.com/sqlc-dev/pqtype"
xgithub "golang.org/x/oauth2/github"
"github.com/coder/coder/v2/coderd/database"
@ -44,6 +45,14 @@ type Config struct {
// DisplayIcon is the path to an image that will be displayed to the user.
DisplayIcon string
// ExtraTokenKeys is a list of extra properties to
// store in the database returned from the token endpoint.
//
// e.g. Slack returns `authed_user` in the token which is
// a payload that contains information about the authenticated
// user.
ExtraTokenKeys []string
// NoRefresh stops Coder from using the refresh token
// to renew the access token.
//
@ -69,6 +78,25 @@ type Config struct {
AppInstallationsURL string
}
// GenerateTokenExtra generates the extra token data to store in the database.
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
if len(c.ExtraTokenKeys) == 0 {
return pqtype.NullRawMessage{}, nil
}
extraMap := map[string]interface{}{}
for _, key := range c.ExtraTokenKeys {
extraMap[key] = token.Extra(key)
}
data, err := json.Marshal(extraMap)
if err != nil {
return pqtype.NullRawMessage{}, err
}
return pqtype.NullRawMessage{
RawMessage: data,
Valid: true,
}, nil
}
// RefreshToken automatically refreshes the token if expired and permitted.
// It returns the token and a bool indicating if the token is valid.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, bool, error) {
@ -101,6 +129,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
// we aren't trying to surface an error, we're just trying to obtain a valid token.
return externalAuthLink, false, nil
}
extra, err := c.GenerateTokenExtra(token)
if err != nil {
return externalAuthLink, false, xerrors.Errorf("generate token extra: %w", err)
}
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
// See the comment below why the retry and cancel is required.
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
@ -135,6 +169,7 @@ validate:
OAuthRefreshToken: token.RefreshToken,
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
OAuthExpiry: token.Expiry,
OAuthExtra: extra,
})
if err != nil {
return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err)
@ -539,6 +574,14 @@ var defaults = map[codersdk.EnhancedExternalAuthProvider]codersdk.ExternalAuthCo
DeviceCodeURL: "https://github.com/login/device/code",
AppInstallationsURL: "https://api.github.com/user/installations",
},
codersdk.EnhancedExternalAuthProviderSlack: {
AuthURL: "https://slack.com/oauth/v2/authorize",
TokenURL: "https://slack.com/api/oauth.v2.access",
DisplayName: "Slack",
DisplayIcon: "/icon/slack.svg",
// See: https://api.slack.com/authentication/oauth-v2#exchanging
ExtraTokenKeys: []string{"authed_user"},
},
}
// jwtConfig is a new OAuth2 config that uses a custom

View File

@ -2,6 +2,7 @@ package externalauth_test
import (
"context"
"encoding/json"
"net/http"
"net/url"
"testing"
@ -43,7 +44,7 @@ func TestRefreshToken(t *testing.T) {
return nil, xerrors.New("should not be called")
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.NoRefresh = true
},
})
@ -74,7 +75,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.NoRefresh = true
},
})
@ -117,7 +118,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, xerrors.New(staticError)
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
},
})
@ -142,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
},
})
@ -175,7 +176,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
})
@ -205,7 +206,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
})
@ -236,7 +237,7 @@ func TestRefreshToken(t *testing.T) {
return jwt.MapClaims{}, nil
}),
},
GitConfigOpt: func(cfg *externalauth.Config) {
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
},
DB: db,
@ -260,6 +261,41 @@ func TestRefreshToken(t *testing.T) {
require.NoError(t, err)
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
})
t.Run("WithExtra", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
fake, config, link := setupOauth2Test(t, testConfig{
FakeIDPOpts: []oidctest.FakeIDPOpt{
oidctest.WithMutateToken(func(token map[string]interface{}) {
token["authed_user"] = map[string]interface{}{
"access_token": token["access_token"],
}
}),
},
ExternalAuthOpt: func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderSlack.String()
cfg.ExtraTokenKeys = []string{"authed_user"}
cfg.ValidateURL = ""
},
DB: db,
})
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
// Force a refresh
link.OAuthExpiry = expired
updated, ok, err := config.RefreshToken(ctx, db, link)
require.NoError(t, err)
require.True(t, ok)
require.True(t, updated.OAuthExtra.Valid)
extra := map[string]interface{}{}
require.NoError(t, json.Unmarshal(updated.OAuthExtra.RawMessage, &extra))
mapping, ok := extra["authed_user"].(map[string]interface{})
require.True(t, ok)
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
})
}
func TestConvertYAML(t *testing.T) {
@ -344,7 +380,7 @@ func TestConvertYAML(t *testing.T) {
type testConfig struct {
FakeIDPOpts []oidctest.FakeIDPOpt
CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig)
GitConfigOpt func(cfg *externalauth.Config)
ExternalAuthOpt func(cfg *externalauth.Config)
// If DB is passed in, the link will be inserted into the DB.
DB database.Store
}
@ -367,7 +403,7 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
ID: providerID,
ValidateURL: fake.WellknownConfig().UserInfoURL,
}
settings.GitConfigOpt(config)
settings.ExternalAuthOpt(config)
oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
"email": "test@coder.com",