mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat: allow storing extra oauth token properties in the database (#10152)
This commit is contained in:
6
coderd/apidoc/docs.go
generated
6
coderd/apidoc/docs.go
generated
@ -8381,6 +8381,12 @@ const docTemplate = `{
|
||||
"description": "DisplayName is shown in the UI to identify the auth config.",
|
||||
"type": "string"
|
||||
},
|
||||
"extra_token_keys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"id": {
|
||||
"description": "ID is a unique identifier for the auth config.\nIt defaults to ` + "`" + `type` + "`" + ` when not provided.",
|
||||
"type": "string"
|
||||
|
6
coderd/apidoc/swagger.json
generated
6
coderd/apidoc/swagger.json
generated
@ -7515,6 +7515,12 @@
|
||||
"description": "DisplayName is shown in the UI to identify the auth config.",
|
||||
"type": "string"
|
||||
},
|
||||
"extra_token_keys": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"id": {
|
||||
"description": "ID is a unique identifier for the auth config.\nIt defaults to `type` when not provided.",
|
||||
"type": "string"
|
||||
|
@ -68,6 +68,7 @@ type FakeIDP struct {
|
||||
// "Authorized Redirect URLs". This can be used to emulate that.
|
||||
hookValidRedirectURL func(redirectURL string) error
|
||||
hookUserInfo func(email string) (jwt.MapClaims, error)
|
||||
hookMutateToken func(token map[string]interface{})
|
||||
fakeCoderd func(req *http.Request) (*http.Response, error)
|
||||
hookOnRefresh func(email string) error
|
||||
// Custom authentication for the client. This is useful if you want
|
||||
@ -112,6 +113,14 @@ func WithRefresh(hook func(email string) error) func(*FakeIDP) {
|
||||
}
|
||||
}
|
||||
|
||||
// WithExtra returns extra fields that be accessed on the returned Oauth Token.
|
||||
// These extra fields can override the default fields (id_token, access_token, etc).
|
||||
func WithMutateToken(mutateToken func(token map[string]interface{})) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookMutateToken = mutateToken
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomClientAuth(hook func(t testing.TB, req *http.Request) (url.Values, error)) func(*FakeIDP) {
|
||||
return func(f *FakeIDP) {
|
||||
f.hookAuthenticateClient = hook
|
||||
@ -621,6 +630,9 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
|
||||
"expires_in": int64((time.Minute * 5).Seconds()),
|
||||
"id_token": f.encodeClaims(t, claims),
|
||||
}
|
||||
if f.hookMutateToken != nil {
|
||||
f.hookMutateToken(token)
|
||||
}
|
||||
// Store the claims for the next refresh
|
||||
f.refreshIDTokenClaims.Store(refreshToken, claims)
|
||||
|
||||
|
@ -4246,6 +4246,7 @@ func (q *FakeQuerier) InsertExternalAuthLink(_ context.Context, arg database.Ins
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthRefreshTokenKeyID: arg.OAuthRefreshTokenKeyID,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
OAuthExtra: arg.OAuthExtra,
|
||||
}
|
||||
q.externalAuthLinks = append(q.externalAuthLinks, gitAuthLink)
|
||||
return gitAuthLink, nil
|
||||
@ -5301,6 +5302,7 @@ func (q *FakeQuerier) UpdateExternalAuthLink(_ context.Context, arg database.Upd
|
||||
gitAuthLink.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
gitAuthLink.OAuthRefreshTokenKeyID = arg.OAuthRefreshTokenKeyID
|
||||
gitAuthLink.OAuthExpiry = arg.OAuthExpiry
|
||||
gitAuthLink.OAuthExtra = arg.OAuthExtra
|
||||
q.externalAuthLinks[index] = gitAuthLink
|
||||
|
||||
return gitAuthLink, nil
|
||||
|
@ -514,6 +514,7 @@ func UserLink(t testing.TB, db database.Store, orig database.UserLink) database.
|
||||
}
|
||||
|
||||
func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAuthLink) database.ExternalAuthLink {
|
||||
msg := takeFirst(&orig.OAuthExtra, &pqtype.NullRawMessage{})
|
||||
link, err := db.InsertExternalAuthLink(genCtx, database.InsertExternalAuthLinkParams{
|
||||
ProviderID: takeFirst(orig.ProviderID, uuid.New().String()),
|
||||
UserID: takeFirst(orig.UserID, uuid.New()),
|
||||
@ -524,6 +525,7 @@ func ExternalAuthLink(t testing.TB, db database.Store, orig database.ExternalAut
|
||||
OAuthExpiry: takeFirst(orig.OAuthExpiry, dbtime.Now().Add(time.Hour*24)),
|
||||
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
|
||||
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
|
||||
OAuthExtra: *msg,
|
||||
})
|
||||
|
||||
require.NoError(t, err, "insert external auth link")
|
||||
|
3
coderd/database/dump.sql
generated
3
coderd/database/dump.sql
generated
@ -359,7 +359,8 @@ CREATE TABLE external_auth_links (
|
||||
oauth_refresh_token text NOT NULL,
|
||||
oauth_expiry timestamp with time zone NOT NULL,
|
||||
oauth_access_token_key_id text,
|
||||
oauth_refresh_token_key_id text
|
||||
oauth_refresh_token_key_id text,
|
||||
oauth_extra jsonb
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN external_auth_links.oauth_access_token_key_id IS 'The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted';
|
||||
|
@ -0,0 +1 @@
|
||||
ALTER TABLE external_auth_links DROP COLUMN "oauth_extra";
|
@ -0,0 +1 @@
|
||||
ALTER TABLE external_auth_links ADD COLUMN "oauth_extra" jsonb;
|
@ -1680,7 +1680,8 @@ type ExternalAuthLink struct {
|
||||
// The ID of the key used to encrypt the OAuth access token. If this is NULL, the access token is not encrypted
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
// The ID of the key used to encrypt the OAuth refresh token. If this is NULL, the refresh token is not encrypted
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
@ -751,7 +751,7 @@ func (q *sqlQuerier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest strin
|
||||
}
|
||||
|
||||
const getExternalAuthLink = `-- name: GetExternalAuthLink :one
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE provider_id = $1 AND user_id = $2
|
||||
`
|
||||
|
||||
type GetExternalAuthLinkParams struct {
|
||||
@ -772,12 +772,13 @@ func (q *sqlQuerier) GetExternalAuthLink(ctx context.Context, arg GetExternalAut
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.OAuthExtra,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getExternalAuthLinksByUserID = `-- name: GetExternalAuthLinksByUserID :many
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id FROM external_auth_links WHERE user_id = $1
|
||||
SELECT provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra FROM external_auth_links WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) {
|
||||
@ -799,6 +800,7 @@ func (q *sqlQuerier) GetExternalAuthLinksByUserID(ctx context.Context, userID uu
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.OAuthExtra,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -823,7 +825,8 @@ INSERT INTO external_auth_links (
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
oauth_expiry,
|
||||
oauth_extra
|
||||
) VALUES (
|
||||
$1,
|
||||
$2,
|
||||
@ -833,20 +836,22 @@ INSERT INTO external_auth_links (
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9
|
||||
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
$9,
|
||||
$10
|
||||
) RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
|
||||
`
|
||||
|
||||
type InsertExternalAuthLinkParams struct {
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExternalAuthLinkParams) (ExternalAuthLink, error) {
|
||||
@ -860,6 +865,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
arg.OAuthExtra,
|
||||
)
|
||||
var i ExternalAuthLink
|
||||
err := row.Scan(
|
||||
@ -872,6 +878,7 @@ func (q *sqlQuerier) InsertExternalAuthLink(ctx context.Context, arg InsertExter
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.OAuthExtra,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -883,19 +890,21 @@ UPDATE external_auth_links SET
|
||||
oauth_access_token_key_id = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_refresh_token_key_id = $7,
|
||||
oauth_expiry = $8
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id
|
||||
oauth_expiry = $8,
|
||||
oauth_extra = $9
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING provider_id, user_id, created_at, updated_at, oauth_access_token, oauth_refresh_token, oauth_expiry, oauth_access_token_key_id, oauth_refresh_token_key_id, oauth_extra
|
||||
`
|
||||
|
||||
type UpdateExternalAuthLinkParams struct {
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
ProviderID string `db:"provider_id" json:"provider_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthAccessTokenKeyID sql.NullString `db:"oauth_access_token_key_id" json:"oauth_access_token_key_id"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthRefreshTokenKeyID sql.NullString `db:"oauth_refresh_token_key_id" json:"oauth_refresh_token_key_id"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
OAuthExtra pqtype.NullRawMessage `db:"oauth_extra" json:"oauth_extra"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) {
|
||||
@ -908,6 +917,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthRefreshTokenKeyID,
|
||||
arg.OAuthExpiry,
|
||||
arg.OAuthExtra,
|
||||
)
|
||||
var i ExternalAuthLink
|
||||
err := row.Scan(
|
||||
@ -920,6 +930,7 @@ func (q *sqlQuerier) UpdateExternalAuthLink(ctx context.Context, arg UpdateExter
|
||||
&i.OAuthExpiry,
|
||||
&i.OAuthAccessTokenKeyID,
|
||||
&i.OAuthRefreshTokenKeyID,
|
||||
&i.OAuthExtra,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -14,7 +14,8 @@ INSERT INTO external_auth_links (
|
||||
oauth_access_token_key_id,
|
||||
oauth_refresh_token,
|
||||
oauth_refresh_token_key_id,
|
||||
oauth_expiry
|
||||
oauth_expiry,
|
||||
oauth_extra
|
||||
) VALUES (
|
||||
$1,
|
||||
$2,
|
||||
@ -24,7 +25,8 @@ INSERT INTO external_auth_links (
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9
|
||||
$9,
|
||||
$10
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateExternalAuthLink :one
|
||||
@ -34,5 +36,6 @@ UPDATE external_auth_links SET
|
||||
oauth_access_token_key_id = $5,
|
||||
oauth_refresh_token = $6,
|
||||
oauth_refresh_token_key_id = $7,
|
||||
oauth_expiry = $8
|
||||
oauth_expiry = $8,
|
||||
oauth_extra = $9
|
||||
WHERE provider_id = $1 AND user_id = $2 RETURNING *;
|
||||
|
@ -53,6 +53,7 @@ overrides:
|
||||
oauth_id_token: OAuthIDToken
|
||||
oauth_refresh_token: OAuthRefreshToken
|
||||
oauth_refresh_token_key_id: OAuthRefreshTokenKeyID
|
||||
oauth_extra: OAuthExtra
|
||||
parameter_type_system_hcl: ParameterTypeSystemHCL
|
||||
userstatus: UserStatus
|
||||
gitsshkey: GitSSHKey
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
)
|
||||
|
||||
// @Summary Get external auth by ID
|
||||
@ -132,6 +133,8 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
|
||||
OAuthExpiry: token.Expiry,
|
||||
// No extra data from device auth!
|
||||
OAuthExtra: pqtype.NullRawMessage{},
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
@ -150,6 +153,7 @@ func (api *API) postExternalAuthDeviceByID(rw http.ResponseWriter, r *http.Reque
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthExtra: pqtype.NullRawMessage{},
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
@ -201,7 +205,15 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
_, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
||||
extra, err := externalAuthConfig.GenerateTokenExtra(state.Token)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to generate token extra.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
||||
ProviderID: externalAuthConfig.ID,
|
||||
UserID: apiKey.UserID,
|
||||
})
|
||||
@ -224,6 +236,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will set as required
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
OAuthExtra: extra,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
@ -242,6 +255,7 @@ func (api *API) externalAuthCallback(externalAuthConfig *externalauth.Config) ht
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
OAuthExtra: extra,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
xgithub "golang.org/x/oauth2/github"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@ -44,6 +45,14 @@ type Config struct {
|
||||
// DisplayIcon is the path to an image that will be displayed to the user.
|
||||
DisplayIcon string
|
||||
|
||||
// ExtraTokenKeys is a list of extra properties to
|
||||
// store in the database returned from the token endpoint.
|
||||
//
|
||||
// e.g. Slack returns `authed_user` in the token which is
|
||||
// a payload that contains information about the authenticated
|
||||
// user.
|
||||
ExtraTokenKeys []string
|
||||
|
||||
// NoRefresh stops Coder from using the refresh token
|
||||
// to renew the access token.
|
||||
//
|
||||
@ -69,6 +78,25 @@ type Config struct {
|
||||
AppInstallationsURL string
|
||||
}
|
||||
|
||||
// GenerateTokenExtra generates the extra token data to store in the database.
|
||||
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
|
||||
if len(c.ExtraTokenKeys) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
extraMap := map[string]interface{}{}
|
||||
for _, key := range c.ExtraTokenKeys {
|
||||
extraMap[key] = token.Extra(key)
|
||||
}
|
||||
data, err := json.Marshal(extraMap)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, err
|
||||
}
|
||||
return pqtype.NullRawMessage{
|
||||
RawMessage: data,
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken automatically refreshes the token if expired and permitted.
|
||||
// It returns the token and a bool indicating if the token is valid.
|
||||
func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAuthLink database.ExternalAuthLink) (database.ExternalAuthLink, bool, error) {
|
||||
@ -101,6 +129,12 @@ func (c *Config) RefreshToken(ctx context.Context, db database.Store, externalAu
|
||||
// we aren't trying to surface an error, we're just trying to obtain a valid token.
|
||||
return externalAuthLink, false, nil
|
||||
}
|
||||
|
||||
extra, err := c.GenerateTokenExtra(token)
|
||||
if err != nil {
|
||||
return externalAuthLink, false, xerrors.Errorf("generate token extra: %w", err)
|
||||
}
|
||||
|
||||
r := retry.New(50*time.Millisecond, 200*time.Millisecond)
|
||||
// See the comment below why the retry and cancel is required.
|
||||
retryCtx, retryCtxCancel := context.WithTimeout(ctx, time.Second)
|
||||
@ -135,6 +169,7 @@ validate:
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthExtra: extra,
|
||||
})
|
||||
if err != nil {
|
||||
return updatedAuthLink, false, xerrors.Errorf("update external auth link: %w", err)
|
||||
@ -539,6 +574,14 @@ var defaults = map[codersdk.EnhancedExternalAuthProvider]codersdk.ExternalAuthCo
|
||||
DeviceCodeURL: "https://github.com/login/device/code",
|
||||
AppInstallationsURL: "https://api.github.com/user/installations",
|
||||
},
|
||||
codersdk.EnhancedExternalAuthProviderSlack: {
|
||||
AuthURL: "https://slack.com/oauth/v2/authorize",
|
||||
TokenURL: "https://slack.com/api/oauth.v2.access",
|
||||
DisplayName: "Slack",
|
||||
DisplayIcon: "/icon/slack.svg",
|
||||
// See: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
ExtraTokenKeys: []string{"authed_user"},
|
||||
},
|
||||
}
|
||||
|
||||
// jwtConfig is a new OAuth2 config that uses a custom
|
||||
|
@ -2,6 +2,7 @@ package externalauth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
@ -43,7 +44,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return nil, xerrors.New("should not be called")
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
@ -74,7 +75,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.NoRefresh = true
|
||||
},
|
||||
})
|
||||
@ -117,7 +118,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, xerrors.New(staticError)
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
@ -142,7 +143,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
},
|
||||
})
|
||||
|
||||
@ -175,7 +176,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, oidctest.StatusError(http.StatusUnauthorized, xerrors.New(staticError))
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
})
|
||||
@ -205,7 +206,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
})
|
||||
@ -236,7 +237,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
return jwt.MapClaims{}, nil
|
||||
}),
|
||||
},
|
||||
GitConfigOpt: func(cfg *externalauth.Config) {
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
|
||||
},
|
||||
DB: db,
|
||||
@ -260,6 +261,41 @@ func TestRefreshToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updated.OAuthAccessToken, dbLink.OAuthAccessToken, "token is updated in the DB")
|
||||
})
|
||||
|
||||
t.Run("WithExtra", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := dbfake.New()
|
||||
fake, config, link := setupOauth2Test(t, testConfig{
|
||||
FakeIDPOpts: []oidctest.FakeIDPOpt{
|
||||
oidctest.WithMutateToken(func(token map[string]interface{}) {
|
||||
token["authed_user"] = map[string]interface{}{
|
||||
"access_token": token["access_token"],
|
||||
}
|
||||
}),
|
||||
},
|
||||
ExternalAuthOpt: func(cfg *externalauth.Config) {
|
||||
cfg.Type = codersdk.EnhancedExternalAuthProviderSlack.String()
|
||||
cfg.ExtraTokenKeys = []string{"authed_user"}
|
||||
cfg.ValidateURL = ""
|
||||
},
|
||||
DB: db,
|
||||
})
|
||||
|
||||
ctx := oidc.ClientContext(context.Background(), fake.HTTPClient(nil))
|
||||
// Force a refresh
|
||||
link.OAuthExpiry = expired
|
||||
|
||||
updated, ok, err := config.RefreshToken(ctx, db, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
require.True(t, updated.OAuthExtra.Valid)
|
||||
extra := map[string]interface{}{}
|
||||
require.NoError(t, json.Unmarshal(updated.OAuthExtra.RawMessage, &extra))
|
||||
mapping, ok := extra["authed_user"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
require.Equal(t, updated.OAuthAccessToken, mapping["access_token"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvertYAML(t *testing.T) {
|
||||
@ -344,7 +380,7 @@ func TestConvertYAML(t *testing.T) {
|
||||
type testConfig struct {
|
||||
FakeIDPOpts []oidctest.FakeIDPOpt
|
||||
CoderOIDCConfigOpts []func(cfg *coderd.OIDCConfig)
|
||||
GitConfigOpt func(cfg *externalauth.Config)
|
||||
ExternalAuthOpt func(cfg *externalauth.Config)
|
||||
// If DB is passed in, the link will be inserted into the DB.
|
||||
DB database.Store
|
||||
}
|
||||
@ -367,7 +403,7 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
|
||||
ID: providerID,
|
||||
ValidateURL: fake.WellknownConfig().UserInfoURL,
|
||||
}
|
||||
settings.GitConfigOpt(config)
|
||||
settings.ExternalAuthOpt(config)
|
||||
|
||||
oauthToken, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
|
||||
"email": "test@coder.com",
|
||||
|
Reference in New Issue
Block a user