mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat: Add GitHub OAuth (#1050)
* Initial oauth * Add Github authentication * Add AuthMethods endpoint * Add frontend * Rename basic authentication to password * Add flags for configuring GitHub auth * Remove name from API keys * Fix authmethods in test * Add stories and display auth methods error
This commit is contained in:
@ -42,6 +42,7 @@ type Options struct {
|
||||
AWSCertificates awsidentity.Certificates
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
GithubOAuth2Config *GithubOAuth2Config
|
||||
ICEServers []webrtc.ICEServer
|
||||
SecureAuthCookie bool
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
@ -62,6 +63,9 @@ func New(options *Options) (http.Handler, func()) {
|
||||
api := &api{
|
||||
Options: options,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
})
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
@ -86,7 +90,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
// This number is arbitrary, but reading/writing
|
||||
// file content is expensive so it should be small.
|
||||
httpmw.RateLimitPerMinute(12),
|
||||
@ -96,7 +100,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/organizations/{organization}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.organization)
|
||||
@ -109,7 +113,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
})
|
||||
r.Route("/parameters/{scope}/{id}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Post("/", api.postParameter)
|
||||
r.Get("/", api.parameters)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
@ -118,7 +122,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/templates/{template}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractTemplateParam(options.Database),
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
@ -132,7 +136,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/templateversions/{templateversion}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractTemplateVersionParam(options.Database),
|
||||
httpmw.ExtractOrganizationParam(options.Database),
|
||||
)
|
||||
@ -154,8 +158,15 @@ func New(options *Options) (http.Handler, func()) {
|
||||
r.Post("/first", api.postFirstUser)
|
||||
r.Post("/login", api.postLogin)
|
||||
r.Post("/logout", api.postLogout)
|
||||
r.Get("/authmethods", api.userAuthMethods)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Post("/", api.postUsers)
|
||||
r.Get("/", api.users)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
@ -193,7 +204,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceAgentParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceAgent)
|
||||
@ -204,7 +215,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceResourceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
@ -212,7 +223,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/workspaces/{workspace}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspace)
|
||||
@ -230,7 +241,7 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceBuildParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
|
@ -53,6 +53,7 @@ import (
|
||||
type Options struct {
|
||||
AWSCertificates awsidentity.Certificates
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
APIRateLimit int
|
||||
@ -123,6 +124,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
|
||||
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
TURNServer: turnServer,
|
||||
|
@ -434,6 +434,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
|
||||
return workspaces, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if len(q.organizations) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
return q.organizations, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@ -856,21 +866,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
|
||||
|
||||
//nolint:gosimple
|
||||
key := database.APIKey{
|
||||
ID: arg.ID,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
UserID: arg.UserID,
|
||||
Application: arg.Application,
|
||||
Name: arg.Name,
|
||||
LastUsed: arg.LastUsed,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LoginType: arg.LoginType,
|
||||
OIDCAccessToken: arg.OIDCAccessToken,
|
||||
OIDCRefreshToken: arg.OIDCRefreshToken,
|
||||
OIDCIDToken: arg.OIDCIDToken,
|
||||
OIDCExpiry: arg.OIDCExpiry,
|
||||
DevurlToken: arg.DevurlToken,
|
||||
ID: arg.ID,
|
||||
HashedSecret: arg.HashedSecret,
|
||||
UserID: arg.UserID,
|
||||
ExpiresAt: arg.ExpiresAt,
|
||||
CreatedAt: arg.CreatedAt,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
LastUsed: arg.LastUsed,
|
||||
LoginType: arg.LoginType,
|
||||
OAuthAccessToken: arg.OAuthAccessToken,
|
||||
OAuthRefreshToken: arg.OAuthRefreshToken,
|
||||
OAuthIDToken: arg.OAuthIDToken,
|
||||
OAuthExpiry: arg.OAuthExpiry,
|
||||
}
|
||||
q.apiKeys = append(q.apiKeys, key)
|
||||
return key, nil
|
||||
@ -1185,9 +1192,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
|
||||
}
|
||||
apiKey.LastUsed = arg.LastUsed
|
||||
apiKey.ExpiresAt = arg.ExpiresAt
|
||||
apiKey.OIDCAccessToken = arg.OIDCAccessToken
|
||||
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
|
||||
apiKey.OIDCExpiry = arg.OIDCExpiry
|
||||
apiKey.OAuthAccessToken = arg.OAuthAccessToken
|
||||
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
|
||||
apiKey.OAuthExpiry = arg.OAuthExpiry
|
||||
q.apiKeys[index] = apiKey
|
||||
return nil
|
||||
}
|
||||
|
16
coderd/database/dump.sql
generated
16
coderd/database/dump.sql
generated
@ -14,9 +14,8 @@ CREATE TYPE log_source AS ENUM (
|
||||
);
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
'password',
|
||||
'github'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_destination_scheme AS ENUM (
|
||||
@ -67,18 +66,15 @@ CREATE TABLE api_keys (
|
||||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_id_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE files (
|
||||
|
@ -4,14 +4,9 @@
|
||||
-- All tables and types are stolen from:
|
||||
-- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql
|
||||
|
||||
--
|
||||
-- Name: users; Type: TABLE; Schema: public; Owner: coder
|
||||
--
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
'password',
|
||||
'github'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
@ -31,10 +26,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users USING btree (email);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users USING btree (username);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS users_username_lower_idx ON users USING btree (lower(username));
|
||||
|
||||
--
|
||||
-- Name: organizations; Type: TABLE; Schema: Owner: coder
|
||||
--
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organizations (
|
||||
id uuid NOT NULL,
|
||||
name text NOT NULL,
|
||||
@ -68,18 +59,15 @@ CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL,
|
||||
oauth_access_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_id_token text DEFAULT ''::text NOT NULL,
|
||||
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
|
@ -56,9 +56,8 @@ func (e *LogSource) Scan(src interface{}) error {
|
||||
type LoginType string
|
||||
|
||||
const (
|
||||
LoginTypeBuiltIn LoginType = "built-in"
|
||||
LoginTypeSaml LoginType = "saml"
|
||||
LoginTypeOIDC LoginType = "oidc"
|
||||
LoginTypePassword LoginType = "password"
|
||||
LoginTypeGithub LoginType = "github"
|
||||
)
|
||||
|
||||
func (e *LoginType) Scan(src interface{}) error {
|
||||
@ -230,21 +229,18 @@ func (e *WorkspaceTransition) Scan(src interface{}) error {
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Application bool `db:"application" json:"application"`
|
||||
Name string `db:"name" json:"name"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
|
@ -18,6 +18,7 @@ type querier interface {
|
||||
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
|
||||
GetOrganizations(ctx context.Context) ([]Organization, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
|
||||
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
|
||||
SELECT
|
||||
id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
|
||||
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
|
||||
FROM
|
||||
api_keys
|
||||
WHERE
|
||||
@ -31,18 +31,15 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
|
||||
&i.ID,
|
||||
&i.HashedSecret,
|
||||
&i.UserID,
|
||||
&i.Application,
|
||||
&i.Name,
|
||||
&i.LastUsed,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OIDCAccessToken,
|
||||
&i.OIDCRefreshToken,
|
||||
&i.OIDCIDToken,
|
||||
&i.OIDCExpiry,
|
||||
&i.DevurlToken,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -53,55 +50,33 @@ INSERT INTO
|
||||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
application,
|
||||
"name",
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oidc_access_token,
|
||||
oidc_refresh_token,
|
||||
oidc_id_token,
|
||||
oidc_expiry,
|
||||
devurl_token
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
|
||||
`
|
||||
|
||||
type InsertAPIKeyParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Application bool `db:"application" json:"application"`
|
||||
Name string `db:"name" json:"name"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
|
||||
@ -109,36 +84,30 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
|
||||
arg.ID,
|
||||
arg.HashedSecret,
|
||||
arg.UserID,
|
||||
arg.Application,
|
||||
arg.Name,
|
||||
arg.LastUsed,
|
||||
arg.ExpiresAt,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.LoginType,
|
||||
arg.OIDCAccessToken,
|
||||
arg.OIDCRefreshToken,
|
||||
arg.OIDCIDToken,
|
||||
arg.OIDCExpiry,
|
||||
arg.DevurlToken,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthIDToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
var i APIKey
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.HashedSecret,
|
||||
&i.UserID,
|
||||
&i.Application,
|
||||
&i.Name,
|
||||
&i.LastUsed,
|
||||
&i.ExpiresAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.LoginType,
|
||||
&i.OIDCAccessToken,
|
||||
&i.OIDCRefreshToken,
|
||||
&i.OIDCIDToken,
|
||||
&i.OIDCExpiry,
|
||||
&i.DevurlToken,
|
||||
&i.OAuthAccessToken,
|
||||
&i.OAuthRefreshToken,
|
||||
&i.OAuthIDToken,
|
||||
&i.OAuthExpiry,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@ -149,20 +118,20 @@ UPDATE
|
||||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
oidc_access_token = $4,
|
||||
oidc_refresh_token = $5,
|
||||
oidc_expiry = $6
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
WHERE
|
||||
id = $1
|
||||
`
|
||||
|
||||
type UpdateAPIKeyByIDParams struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
ID string `db:"id" json:"id"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
|
||||
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
|
||||
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
|
||||
@ -170,9 +139,9 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
arg.ID,
|
||||
arg.LastUsed,
|
||||
arg.ExpiresAt,
|
||||
arg.OIDCAccessToken,
|
||||
arg.OIDCRefreshToken,
|
||||
arg.OIDCExpiry,
|
||||
arg.OAuthAccessToken,
|
||||
arg.OAuthRefreshToken,
|
||||
arg.OAuthExpiry,
|
||||
)
|
||||
return err
|
||||
}
|
||||
@ -453,6 +422,42 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOrganizations = `-- name: GetOrganizations :many
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at
|
||||
FROM
|
||||
organizations
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOrganizations)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Organization
|
||||
for rows.Next() {
|
||||
var i Organization
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
|
||||
SELECT
|
||||
id, name, description, created_at, updated_at
|
||||
|
@ -14,37 +14,18 @@ INSERT INTO
|
||||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
application,
|
||||
"name",
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oidc_access_token,
|
||||
oidc_refresh_token,
|
||||
oidc_id_token,
|
||||
oidc_expiry,
|
||||
devurl_token
|
||||
oauth_access_token,
|
||||
oauth_refresh_token,
|
||||
oauth_id_token,
|
||||
oauth_expiry
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING *;
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *;
|
||||
|
||||
-- name: UpdateAPIKeyByID :exec
|
||||
UPDATE
|
||||
@ -52,8 +33,8 @@ UPDATE
|
||||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
oidc_access_token = $4,
|
||||
oidc_refresh_token = $5,
|
||||
oidc_expiry = $6
|
||||
oauth_access_token = $4,
|
||||
oauth_refresh_token = $5,
|
||||
oauth_expiry = $6
|
||||
WHERE
|
||||
id = $1;
|
||||
|
@ -1,3 +1,9 @@
|
||||
-- name: GetOrganizations :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations;
|
||||
|
||||
-- name: GetOrganizationByID :one
|
||||
SELECT
|
||||
*
|
||||
|
@ -21,10 +21,10 @@ overrides:
|
||||
rename:
|
||||
api_key: APIKey
|
||||
login_type_oidc: LoginTypeOIDC
|
||||
oidc_access_token: OIDCAccessToken
|
||||
oidc_expiry: OIDCExpiry
|
||||
oidc_id_token: OIDCIDToken
|
||||
oidc_refresh_token: OIDCRefreshToken
|
||||
oauth_access_token: OAuthAccessToken
|
||||
oauth_expiry: OAuthExpiry
|
||||
oauth_id_token: OAuthIDToken
|
||||
oauth_refresh_token: OAuthRefreshToken
|
||||
parameter_type_system_hcl: ParameterTypeSystemHCL
|
||||
userstatus: UserStatus
|
||||
gitsshkey: GitSSHKey
|
||||
|
@ -20,12 +20,6 @@ import (
|
||||
// AuthCookie represents the name of the cookie the API key is stored in.
|
||||
const AuthCookie = "session_token"
|
||||
|
||||
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
|
||||
// It is abstracted for simple testing.
|
||||
type OAuth2Config interface {
|
||||
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
|
||||
}
|
||||
|
||||
type apiKeyContextKey struct{}
|
||||
|
||||
// APIKey returns the API key from the ExtractAPIKey handler.
|
||||
@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey {
|
||||
return apiKey
|
||||
}
|
||||
|
||||
// OAuth2Configs is a collection of configurations for OAuth-based authentication.
|
||||
// This should be extended to support other authentication types in the future.
|
||||
type OAuth2Configs struct {
|
||||
Github OAuth2Config
|
||||
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key.
|
||||
// It handles extending an API key if it comes close to expiry,
|
||||
// updating the last used time in the database.
|
||||
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
|
||||
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(AuthCookie)
|
||||
@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
||||
// Tracks if the API key has properties updated!
|
||||
changed := false
|
||||
|
||||
if key.LoginType == database.LoginTypeOIDC {
|
||||
// Check if the OIDC token is expired!
|
||||
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
|
||||
if key.LoginType != database.LoginTypePassword {
|
||||
// Check if the OAuth token is expired!
|
||||
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
|
||||
var oauthConfig OAuth2Config
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = oauth.Github
|
||||
default:
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType),
|
||||
})
|
||||
return
|
||||
}
|
||||
// If it is, let's refresh it from the provided config!
|
||||
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||
AccessToken: key.OIDCAccessToken,
|
||||
RefreshToken: key.OIDCRefreshToken,
|
||||
Expiry: key.OIDCExpiry,
|
||||
AccessToken: key.OAuthAccessToken,
|
||||
RefreshToken: key.OAuthRefreshToken,
|
||||
Expiry: key.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
||||
})
|
||||
return
|
||||
}
|
||||
key.OIDCAccessToken = token.AccessToken
|
||||
key.OIDCRefreshToken = token.RefreshToken
|
||||
key.OIDCExpiry = token.Expiry
|
||||
key.OAuthAccessToken = token.AccessToken
|
||||
key.OAuthRefreshToken = token.RefreshToken
|
||||
key.OAuthExpiry = token.Expiry
|
||||
key.ExpiresAt = token.Expiry
|
||||
changed = true
|
||||
}
|
||||
@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
|
||||
changed = true
|
||||
}
|
||||
// Only update the ExpiresAt once an hour to prevent database spam.
|
||||
// We extend the ExpiresAt to reduce reauthentication.
|
||||
// We extend the ExpiresAt to reduce re-authentication.
|
||||
apiKeyLifetime := 24 * time.Hour
|
||||
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
|
||||
key.ExpiresAt = now.Add(apiKeyLifetime)
|
||||
changed = true
|
||||
}
|
||||
|
||||
if changed {
|
||||
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
LastUsed: key.LastUsed,
|
||||
OIDCAccessToken: key.OIDCAccessToken,
|
||||
OIDCRefreshToken: key.OIDCRefreshToken,
|
||||
OIDCExpiry: key.OIDCExpiry,
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
OAuthAccessToken: key.OAuthAccessToken,
|
||||
OAuthRefreshToken: key.OAuthRefreshToken,
|
||||
OAuthExpiry: key.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
|
@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) {
|
||||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) {
|
||||
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) {
|
||||
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("OIDCNotExpired", func(t *testing.T) {
|
||||
t.Run("OAuthNotExpired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) {
|
||||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
})
|
||||
@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) {
|
||||
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("OIDCRefresh", func(t *testing.T) {
|
||||
t.Run("OAuthRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
db = databasefake.New()
|
||||
@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) {
|
||||
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
LastUsed: database.Now(),
|
||||
OIDCExpiry: database.Now().AddDate(0, 0, -1),
|
||||
OAuthExpiry: database.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
token := &oauth2.Token{
|
||||
@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) {
|
||||
RefreshToken: "moo",
|
||||
Expiry: database.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
httpmw.ExtractAPIKey(db, &oauth2Config{
|
||||
tokenSource: &oauth2TokenSource{
|
||||
token: func() (*oauth2.Token, error) {
|
||||
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{
|
||||
Github: &oauth2Config{
|
||||
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
|
||||
return token, nil
|
||||
},
|
||||
}),
|
||||
},
|
||||
})(successHandler).ServeHTTP(rw, r)
|
||||
res := rw.Result()
|
||||
@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) {
|
||||
|
||||
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
|
||||
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
|
||||
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
|
||||
})
|
||||
}
|
||||
|
||||
type oauth2Config struct {
|
||||
tokenSource *oauth2TokenSource
|
||||
tokenSource oauth2TokenSource
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
|
||||
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return o.tokenSource
|
||||
}
|
||||
|
||||
type oauth2TokenSource struct {
|
||||
token func() (*oauth2.Token, error)
|
||||
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o.token()
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{}, nil
|
||||
}
|
||||
|
||||
type oauth2TokenSource func() (*oauth2.Token, error)
|
||||
|
||||
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o()
|
||||
}
|
||||
|
132
coderd/httpmw/oauth2.go
Normal file
132
coderd/httpmw/oauth2.go
Normal file
@ -0,0 +1,132 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2StateCookieName = "oauth_state"
|
||||
oauth2RedirectCookieName = "oauth_redirect"
|
||||
)
|
||||
|
||||
type oauth2StateKey struct{}
|
||||
|
||||
type OAuth2State struct {
|
||||
Token *oauth2.Token
|
||||
Redirect string
|
||||
}
|
||||
|
||||
// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing.
|
||||
// *oauth2.Config should be used instead of implementing this in production.
|
||||
type OAuth2Config interface {
|
||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
|
||||
}
|
||||
|
||||
// OAuth2 returns the state from an oauth request.
|
||||
func OAuth2(r *http.Request) OAuth2State {
|
||||
oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State)
|
||||
if !ok {
|
||||
panic("developer error: oauth middleware not provided")
|
||||
}
|
||||
return oauth
|
||||
}
|
||||
|
||||
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
|
||||
// URLs, and handling the exchange inbound. Any route that does not have
|
||||
// a "code" URL parameter will be redirected.
|
||||
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if config == nil {
|
||||
httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{
|
||||
Message: fmt.Sprintf("The oauth2 method requested is not configured!"),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
// If the code isn't provided, we'll redirect!
|
||||
state, err := cryptorand.String(32)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate state string: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: oauth2StateCookieName,
|
||||
Value: state,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
// Redirect must always be specified, otherwise
|
||||
// an old redirect could apply!
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: oauth2RedirectCookieName,
|
||||
Value: r.URL.Query().Get("redirect"),
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "state must be provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
stateCookie, err := r.Cookie(oauth2StateCookieName)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName),
|
||||
})
|
||||
return
|
||||
}
|
||||
if stateCookie.Value != state {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "state mismatched",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var redirect string
|
||||
stateRedirect, err := r.Cookie(oauth2RedirectCookieName)
|
||||
if err == nil {
|
||||
redirect = stateRedirect.Value
|
||||
}
|
||||
|
||||
oauthToken, err := config.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("exchange oauth code: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{
|
||||
Token: oauthToken,
|
||||
Redirect: redirect,
|
||||
})
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
98
coderd/httpmw/oauth2_test.go
Normal file
98
coderd/httpmw/oauth2_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
type testOAuth2Provider struct {
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{
|
||||
AccessToken: "hello",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOAuth2(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NotSetup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("RedirectWithoutCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
location := res.Header().Get("Location")
|
||||
if !assert.NotEmpty(t, location) {
|
||||
return
|
||||
}
|
||||
require.Len(t, res.Result().Cookies(), 2)
|
||||
cookie := res.Result().Cookies()[1]
|
||||
require.Equal(t, "/dashboard", cookie.Value)
|
||||
})
|
||||
t.Run("NoState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("NoStateCookie", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("MismatchedState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "mismatch",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
})
|
||||
t.Run("ExchangeCodeAndState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/?code=test&state=something", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: "something",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_redirect",
|
||||
Value: "/dashboard",
|
||||
})
|
||||
res := httptest.NewRecorder()
|
||||
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
require.Equal(t, "/dashboard", state.Redirect)
|
||||
})).ServeHTTP(res, req)
|
||||
})
|
||||
}
|
@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) {
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
LoginType: database.LoginTypePassword,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
|
155
coderd/userauth.go
Normal file
155
coderd/userauth.go
Normal file
@ -0,0 +1,155 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// GithubOAuth2Provider exposes required functions for the Github authentication flow.
|
||||
type GithubOAuth2Config struct {
|
||||
httpmw.OAuth2Config
|
||||
AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error)
|
||||
ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error)
|
||||
ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error)
|
||||
|
||||
AllowSignups bool
|
||||
AllowOrganizations []string
|
||||
}
|
||||
|
||||
func (api *api) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{
|
||||
Password: true,
|
||||
Github: api.GithubOAuth2Config != nil,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
|
||||
oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))
|
||||
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get authenticated github user organizations: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
var selectedMembership *github.Membership
|
||||
for _, membership := range memberships {
|
||||
for _, allowed := range api.GithubOAuth2Config.AllowOrganizations {
|
||||
if *membership.Organization.Login != allowed {
|
||||
continue
|
||||
}
|
||||
selectedMembership = membership
|
||||
break
|
||||
}
|
||||
}
|
||||
if selectedMembership == nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get personal github user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
// Search for existing users with matching and verified emails.
|
||||
// If a verified GitHub email matches a Coder user, we will return.
|
||||
for _, email := range emails {
|
||||
if email.Verified == nil {
|
||||
continue
|
||||
}
|
||||
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: *email.Email,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get user by email: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !*email.Verified {
|
||||
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
|
||||
Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// If the user doesn't exist, create a new one!
|
||||
if user.ID == uuid.Nil {
|
||||
if !api.GithubOAuth2Config.AllowSignups {
|
||||
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
|
||||
Message: "Signups are disabled for Github authentication!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var organizationID uuid.UUID
|
||||
organizations, _ := api.Database.GetOrganizations(r.Context())
|
||||
if len(organizations) > 0 {
|
||||
// Add the user to the first organization. Once multi-organization
|
||||
// support is added, we should enable a configuration map of user
|
||||
// email to organization.
|
||||
organizationID = organizations[0].ID
|
||||
}
|
||||
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get authenticated github user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: *ghUser.Email,
|
||||
Username: *ghUser.Login,
|
||||
OrganizationID: organizationID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("create user: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeGithub,
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
|
||||
redirect := state.Redirect
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
205
coderd/userauth_test.go
Normal file
205
coderd/userauth_test.go
Normal file
@ -0,0 +1,205 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type oauth2Config struct{}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUserAuthMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Password", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
methods, err := client.AuthMethods(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, methods.Password)
|
||||
require.False(t, methods.Github)
|
||||
})
|
||||
t.Run("Github", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{},
|
||||
})
|
||||
methods, err := client.AuthMethods(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, methods.Password)
|
||||
require.True(t, methods.Github)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserOAuth2Github(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NotInAllowedOrganization", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("kyle"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
})
|
||||
t.Run("UnverifiedEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("testuser@coder.com"),
|
||||
Verified: github.Bool(false),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("BlockSignups", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
})
|
||||
t.Run("Signup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowSignups: true,
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{
|
||||
Login: github.String("kyle"),
|
||||
Email: github.String("kyle@coder.com"),
|
||||
}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
t.Run("Login", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
Organization: &github.Organization{
|
||||
Login: github.String("coder"),
|
||||
},
|
||||
}}, nil
|
||||
},
|
||||
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
||||
return &github.User{}, nil
|
||||
},
|
||||
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
||||
return []*github.UserEmail{{
|
||||
Email: github.String("testuser@coder.com"),
|
||||
Verified: github.Bool(true),
|
||||
}}, nil
|
||||
},
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
resp := oauth2Callback(t, client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
state := "somestate"
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest("GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
})
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = res.Body.Close()
|
||||
})
|
||||
return res
|
||||
}
|
297
coderd/users.go
297
coderd/users.go
@ -1,6 +1,7 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@ -71,66 +72,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash(createUser.Password)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("hash password: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create the user, organization, and membership to the user.
|
||||
var user database.User
|
||||
var organization database.Organization
|
||||
err = api.Database.InTx(func(db database.Store) error {
|
||||
user, err = api.Database.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: createUser.Email,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
Username: createUser.Username,
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
|
||||
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
|
||||
ID: uuid.New(),
|
||||
Name: createUser.OrganizationName,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization: %w", err)
|
||||
}
|
||||
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{"organization-admin"},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: createUser.Email,
|
||||
Username: createUser.Username,
|
||||
Password: createUser.Password,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
@ -141,7 +86,7 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.CreateFirstUserResponse{
|
||||
UserID: user.ID,
|
||||
OrganizationID: organization.ID,
|
||||
OrganizationID: organizationID,
|
||||
})
|
||||
}
|
||||
|
||||
@ -262,56 +207,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash(createUser.Password)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("hash password: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err = api.Database.InTx(func(db database.Store) error {
|
||||
user, err = db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: createUser.Email,
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
Username: createUser.Username,
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
user, _, err := api.createUser(r.Context(), createUser)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: err.Error(),
|
||||
@ -542,41 +438,13 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: user.ID,
|
||||
ExpiresAt: database.Now().Add(24 * time.Hour),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
})
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.LoginWithPasswordResponse{
|
||||
SessionToken: sessionToken,
|
||||
@ -595,35 +463,15 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: apiKey.UserID,
|
||||
ExpiresAt: database.Now().AddDate(1, 0, 0), // Expire after 1 year (same as v1)
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
generatedAPIKey := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: generatedAPIKey})
|
||||
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken})
|
||||
}
|
||||
|
||||
// Clear the user's session cookie
|
||||
@ -984,6 +832,117 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) {
|
||||
return id, secret, nil
|
||||
}
|
||||
|
||||
func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) {
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: params.UserID,
|
||||
ExpiresAt: database.Now().Add(24 * time.Hour),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
OAuthAccessToken: params.OAuthAccessToken,
|
||||
OAuthRefreshToken: params.OAuthRefreshToken,
|
||||
OAuthIDToken: params.OAuthIDToken,
|
||||
OAuthExpiry: params.OAuthExpiry,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("insert api key: %s", err.Error()),
|
||||
})
|
||||
return "", false
|
||||
}
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
})
|
||||
return sessionToken, true
|
||||
}
|
||||
|
||||
func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
|
||||
var user database.User
|
||||
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
|
||||
// If no organization is provided, create a new one for the user.
|
||||
if req.OrganizationID == uuid.Nil {
|
||||
organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
|
||||
ID: uuid.New(),
|
||||
Name: req.Username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization: %w", err)
|
||||
}
|
||||
req.OrganizationID = organization.ID
|
||||
}
|
||||
|
||||
params := database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: req.Email,
|
||||
Username: req.Username,
|
||||
LoginType: database.LoginTypePassword,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
}
|
||||
// If a user signs up with OAuth, they can have no password!
|
||||
if req.Password != "" {
|
||||
hashedPassword, err := userpassword.Hash(req.Password)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("hash password: %w", err)
|
||||
}
|
||||
params.HashedPassword = []byte(hashedPassword)
|
||||
}
|
||||
|
||||
var err error
|
||||
user, err = db.InsertUser(ctx, params)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("generate user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user gitsshkey: %w", err)
|
||||
}
|
||||
_, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
|
||||
OrganizationID: req.OrganizationID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
Roles: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create organization member: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func convertUser(user database.User) codersdk.User {
|
||||
return codersdk.User{
|
||||
ID: user.ID,
|
||||
|
@ -241,13 +241,14 @@ func TestUpdateUserProfile(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
existentUser, _ := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
|
||||
existentUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
|
||||
Email: "bruno@coder.com",
|
||||
Username: "bruno",
|
||||
Password: "password",
|
||||
OrganizationID: user.OrganizationID,
|
||||
})
|
||||
_, err := client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
|
||||
require.NoError(t, err)
|
||||
_, err = client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
|
||||
Username: existentUser.Username,
|
||||
Email: "newemail@coder.com",
|
||||
})
|
||||
|
Reference in New Issue
Block a user