feat: Add GitHub OAuth (#1050)

* Initial oauth

* Add Github authentication

* Add AuthMethods endpoint

* Add frontend

* Rename basic authentication to password

* Add flags for configuring GitHub auth

* Remove name from API keys

* Fix authmethods in test

* Add stories and display auth methods error
This commit is contained in:
Kyle Carberry
2022-04-23 17:58:57 -05:00
committed by GitHub
parent 3976994781
commit 7496c3da81
41 changed files with 1251 additions and 422 deletions

View File

@ -42,6 +42,7 @@ type Options struct {
AWSCertificates awsidentity.Certificates
AzureCertificates x509.VerifyOptions
GoogleTokenValidator *idtoken.Validator
GithubOAuth2Config *GithubOAuth2Config
ICEServers []webrtc.ICEServer
SecureAuthCookie bool
SSHKeygenAlgorithm gitsshkey.Algorithm
@ -62,6 +63,9 @@ func New(options *Options) (http.Handler, func()) {
api := &api{
Options: options,
}
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
})
r := chi.NewRouter()
r.Route("/api/v2", func(r chi.Router) {
@ -86,7 +90,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/files", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
// This number is arbitrary, but reading/writing
// file content is expensive so it should be small.
httpmw.RateLimitPerMinute(12),
@ -96,7 +100,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/organizations/{organization}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractOrganizationParam(options.Database),
)
r.Get("/", api.organization)
@ -109,7 +113,7 @@ func New(options *Options) (http.Handler, func()) {
})
})
r.Route("/parameters/{scope}/{id}", func(r chi.Router) {
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
r.Use(apiKeyMiddleware)
r.Post("/", api.postParameter)
r.Get("/", api.parameters)
r.Route("/{name}", func(r chi.Router) {
@ -118,7 +122,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/templates/{template}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractTemplateParam(options.Database),
httpmw.ExtractOrganizationParam(options.Database),
)
@ -132,7 +136,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/templateversions/{templateversion}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractTemplateVersionParam(options.Database),
httpmw.ExtractOrganizationParam(options.Database),
)
@ -154,8 +158,15 @@ func New(options *Options) (http.Handler, func()) {
r.Post("/first", api.postFirstUser)
r.Post("/login", api.postLogin)
r.Post("/logout", api.postLogout)
r.Get("/authmethods", api.userAuthMethods)
r.Route("/oauth2", func(r chi.Router) {
r.Route("/github", func(r chi.Router) {
r.Use(httpmw.ExtractOAuth2(options.GithubOAuth2Config))
r.Get("/callback", api.userOAuth2Github)
})
})
r.Group(func(r chi.Router) {
r.Use(httpmw.ExtractAPIKey(options.Database, nil))
r.Use(apiKeyMiddleware)
r.Post("/", api.postUsers)
r.Get("/", api.users)
r.Route("/{user}", func(r chi.Router) {
@ -193,7 +204,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceAgentParam(options.Database),
)
r.Get("/", api.workspaceAgent)
@ -204,7 +215,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceResourceParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
)
@ -212,7 +223,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/workspaces/{workspace}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspace)
@ -230,7 +241,7 @@ func New(options *Options) (http.Handler, func()) {
})
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
apiKeyMiddleware,
httpmw.ExtractWorkspaceBuildParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
)

View File

@ -53,6 +53,7 @@ import (
type Options struct {
AWSCertificates awsidentity.Certificates
AzureCertificates x509.VerifyOptions
GithubOAuth2Config *coderd.GithubOAuth2Config
GoogleTokenValidator *idtoken.Validator
SSHKeygenAlgorithm gitsshkey.Algorithm
APIRateLimit int
@ -123,6 +124,7 @@ func New(t *testing.T, options *Options) *codersdk.Client {
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
TURNServer: turnServer,

View File

@ -434,6 +434,16 @@ func (q *fakeQuerier) GetWorkspacesByUserID(_ context.Context, req database.GetW
return workspaces, nil
}
func (q *fakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
if len(q.organizations) == 0 {
return nil, sql.ErrNoRows
}
return q.organizations, nil
}
func (q *fakeQuerier) GetOrganizationByID(_ context.Context, id uuid.UUID) (database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -856,21 +866,18 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP
//nolint:gosimple
key := database.APIKey{
ID: arg.ID,
HashedSecret: arg.HashedSecret,
UserID: arg.UserID,
Application: arg.Application,
Name: arg.Name,
LastUsed: arg.LastUsed,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LoginType: arg.LoginType,
OIDCAccessToken: arg.OIDCAccessToken,
OIDCRefreshToken: arg.OIDCRefreshToken,
OIDCIDToken: arg.OIDCIDToken,
OIDCExpiry: arg.OIDCExpiry,
DevurlToken: arg.DevurlToken,
ID: arg.ID,
HashedSecret: arg.HashedSecret,
UserID: arg.UserID,
ExpiresAt: arg.ExpiresAt,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
LastUsed: arg.LastUsed,
LoginType: arg.LoginType,
OAuthAccessToken: arg.OAuthAccessToken,
OAuthRefreshToken: arg.OAuthRefreshToken,
OAuthIDToken: arg.OAuthIDToken,
OAuthExpiry: arg.OAuthExpiry,
}
q.apiKeys = append(q.apiKeys, key)
return key, nil
@ -1185,9 +1192,9 @@ func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPI
}
apiKey.LastUsed = arg.LastUsed
apiKey.ExpiresAt = arg.ExpiresAt
apiKey.OIDCAccessToken = arg.OIDCAccessToken
apiKey.OIDCRefreshToken = arg.OIDCRefreshToken
apiKey.OIDCExpiry = arg.OIDCExpiry
apiKey.OAuthAccessToken = arg.OAuthAccessToken
apiKey.OAuthRefreshToken = arg.OAuthRefreshToken
apiKey.OAuthExpiry = arg.OAuthExpiry
q.apiKeys[index] = apiKey
return nil
}

View File

@ -14,9 +14,8 @@ CREATE TYPE log_source AS ENUM (
);
CREATE TYPE login_type AS ENUM (
'built-in',
'saml',
'oidc'
'password',
'github'
);
CREATE TYPE parameter_destination_scheme AS ENUM (
@ -67,18 +66,15 @@ CREATE TABLE api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
user_id uuid NOT NULL,
application boolean NOT NULL,
name text NOT NULL,
last_used timestamp with time zone NOT NULL,
expires_at timestamp with time zone NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
login_type login_type NOT NULL,
oidc_access_token text DEFAULT ''::text NOT NULL,
oidc_refresh_token text DEFAULT ''::text NOT NULL,
oidc_id_token text DEFAULT ''::text NOT NULL,
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
devurl_token boolean DEFAULT false NOT NULL
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_id_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL
);
CREATE TABLE files (

View File

@ -4,14 +4,9 @@
-- All tables and types are stolen from:
-- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql
--
-- Name: users; Type: TABLE; Schema: public; Owner: coder
--
CREATE TYPE login_type AS ENUM (
'built-in',
'saml',
'oidc'
'password',
'github'
);
CREATE TABLE IF NOT EXISTS users (
@ -31,10 +26,6 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users USING btree (email);
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users USING btree (username);
CREATE UNIQUE INDEX IF NOT EXISTS users_username_lower_idx ON users USING btree (lower(username));
--
-- Name: organizations; Type: TABLE; Schema: Owner: coder
--
CREATE TABLE IF NOT EXISTS organizations (
id uuid NOT NULL,
name text NOT NULL,
@ -68,18 +59,15 @@ CREATE TABLE IF NOT EXISTS api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
user_id uuid NOT NULL,
application boolean NOT NULL,
name text NOT NULL,
last_used timestamp with time zone NOT NULL,
expires_at timestamp with time zone NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
login_type login_type NOT NULL,
oidc_access_token text DEFAULT ''::text NOT NULL,
oidc_refresh_token text DEFAULT ''::text NOT NULL,
oidc_id_token text DEFAULT ''::text NOT NULL,
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
devurl_token boolean DEFAULT false NOT NULL,
oauth_access_token text DEFAULT ''::text NOT NULL,
oauth_refresh_token text DEFAULT ''::text NOT NULL,
oauth_id_token text DEFAULT ''::text NOT NULL,
oauth_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
PRIMARY KEY (id)
);

View File

@ -56,9 +56,8 @@ func (e *LogSource) Scan(src interface{}) error {
type LoginType string
const (
LoginTypeBuiltIn LoginType = "built-in"
LoginTypeSaml LoginType = "saml"
LoginTypeOIDC LoginType = "oidc"
LoginTypePassword LoginType = "password"
LoginTypeGithub LoginType = "github"
)
func (e *LoginType) Scan(src interface{}) error {
@ -230,21 +229,18 @@ func (e *WorkspaceTransition) Scan(src interface{}) error {
}
type APIKey struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Application bool `db:"application" json:"application"`
Name string `db:"name" json:"name"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
type File struct {

View File

@ -18,6 +18,7 @@ type querier interface {
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
GetOrganizations(ctx context.Context) ([]Organization, error)
GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)

View File

@ -15,7 +15,7 @@ import (
const getAPIKeyByID = `-- name: GetAPIKeyByID :one
SELECT
id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
FROM
api_keys
WHERE
@ -31,18 +31,15 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro
&i.ID,
&i.HashedSecret,
&i.UserID,
&i.Application,
&i.Name,
&i.LastUsed,
&i.ExpiresAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OIDCAccessToken,
&i.OIDCRefreshToken,
&i.OIDCIDToken,
&i.OIDCExpiry,
&i.DevurlToken,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthIDToken,
&i.OAuthExpiry,
)
return i, err
}
@ -53,55 +50,33 @@ INSERT INTO
id,
hashed_secret,
user_id,
application,
"name",
last_used,
expires_at,
created_at,
updated_at,
login_type,
oidc_access_token,
oidc_refresh_token,
oidc_id_token,
oidc_expiry,
devurl_token
oauth_access_token,
oauth_refresh_token,
oauth_id_token,
oauth_expiry
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING id, hashed_secret, user_id, application, name, last_used, expires_at, created_at, updated_at, login_type, oidc_access_token, oidc_refresh_token, oidc_id_token, oidc_expiry, devurl_token
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, oauth_access_token, oauth_refresh_token, oauth_id_token, oauth_expiry
`
type InsertAPIKeyParams struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Application bool `db:"application" json:"application"`
Name string `db:"name" json:"name"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthIDToken string `db:"oauth_id_token" json:"oauth_id_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) {
@ -109,36 +84,30 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (
arg.ID,
arg.HashedSecret,
arg.UserID,
arg.Application,
arg.Name,
arg.LastUsed,
arg.ExpiresAt,
arg.CreatedAt,
arg.UpdatedAt,
arg.LoginType,
arg.OIDCAccessToken,
arg.OIDCRefreshToken,
arg.OIDCIDToken,
arg.OIDCExpiry,
arg.DevurlToken,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthIDToken,
arg.OAuthExpiry,
)
var i APIKey
err := row.Scan(
&i.ID,
&i.HashedSecret,
&i.UserID,
&i.Application,
&i.Name,
&i.LastUsed,
&i.ExpiresAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.LoginType,
&i.OIDCAccessToken,
&i.OIDCRefreshToken,
&i.OIDCIDToken,
&i.OIDCExpiry,
&i.DevurlToken,
&i.OAuthAccessToken,
&i.OAuthRefreshToken,
&i.OAuthIDToken,
&i.OAuthExpiry,
)
return i, err
}
@ -149,20 +118,20 @@ UPDATE
SET
last_used = $2,
expires_at = $3,
oidc_access_token = $4,
oidc_refresh_token = $5,
oidc_expiry = $6
oauth_access_token = $4,
oauth_refresh_token = $5,
oauth_expiry = $6
WHERE
id = $1
`
type UpdateAPIKeyByIDParams struct {
ID string `db:"id" json:"id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
ID string `db:"id" json:"id"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
OAuthAccessToken string `db:"oauth_access_token" json:"oauth_access_token"`
OAuthRefreshToken string `db:"oauth_refresh_token" json:"oauth_refresh_token"`
OAuthExpiry time.Time `db:"oauth_expiry" json:"oauth_expiry"`
}
func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error {
@ -170,9 +139,9 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
arg.ID,
arg.LastUsed,
arg.ExpiresAt,
arg.OIDCAccessToken,
arg.OIDCRefreshToken,
arg.OIDCExpiry,
arg.OAuthAccessToken,
arg.OAuthRefreshToken,
arg.OAuthExpiry,
)
return err
}
@ -453,6 +422,42 @@ func (q *sqlQuerier) GetOrganizationByName(ctx context.Context, name string) (Or
return i, err
}
const getOrganizations = `-- name: GetOrganizations :many
SELECT
id, name, description, created_at, updated_at
FROM
organizations
`
func (q *sqlQuerier) GetOrganizations(ctx context.Context) ([]Organization, error) {
rows, err := q.db.QueryContext(ctx, getOrganizations)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Organization
for rows.Next() {
var i Organization
if err := rows.Scan(
&i.ID,
&i.Name,
&i.Description,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getOrganizationsByUserID = `-- name: GetOrganizationsByUserID :many
SELECT
id, name, description, created_at, updated_at

View File

@ -14,37 +14,18 @@ INSERT INTO
id,
hashed_secret,
user_id,
application,
"name",
last_used,
expires_at,
created_at,
updated_at,
login_type,
oidc_access_token,
oidc_refresh_token,
oidc_id_token,
oidc_expiry,
devurl_token
oauth_access_token,
oauth_refresh_token,
oauth_id_token,
oauth_expiry
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING *;
-- name: UpdateAPIKeyByID :exec
UPDATE
@ -52,8 +33,8 @@ UPDATE
SET
last_used = $2,
expires_at = $3,
oidc_access_token = $4,
oidc_refresh_token = $5,
oidc_expiry = $6
oauth_access_token = $4,
oauth_refresh_token = $5,
oauth_expiry = $6
WHERE
id = $1;

View File

@ -1,3 +1,9 @@
-- name: GetOrganizations :many
SELECT
*
FROM
organizations;
-- name: GetOrganizationByID :one
SELECT
*

View File

@ -21,10 +21,10 @@ overrides:
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
oidc_access_token: OIDCAccessToken
oidc_expiry: OIDCExpiry
oidc_id_token: OIDCIDToken
oidc_refresh_token: OIDCRefreshToken
oauth_access_token: OAuthAccessToken
oauth_expiry: OAuthExpiry
oauth_id_token: OAuthIDToken
oauth_refresh_token: OAuthRefreshToken
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus
gitsshkey: GitSSHKey

View File

@ -20,12 +20,6 @@ import (
// AuthCookie represents the name of the cookie the API key is stored in.
const AuthCookie = "session_token"
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
// It is abstracted for simple testing.
type OAuth2Config interface {
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
type apiKeyContextKey struct{}
// APIKey returns the API key from the ExtractAPIKey handler.
@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey {
return apiKey
}
// OAuth2Configs is a collection of configurations for OAuth-based authentication.
// This should be extended to support other authentication types in the future.
type OAuth2Configs struct {
Github OAuth2Config
}
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(AuthCookie)
@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
// Tracks if the API key has properties updated!
changed := false
if key.LoginType == database.LoginTypeOIDC {
// Check if the OIDC token is expired!
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
if key.LoginType != database.LoginTypePassword {
// Check if the OAuth token is expired!
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
var oauthConfig OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = oauth.Github
default:
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType),
})
return
}
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OIDCAccessToken,
RefreshToken: key.OIDCRefreshToken,
Expiry: key.OIDCExpiry,
AccessToken: key.OAuthAccessToken,
RefreshToken: key.OAuthRefreshToken,
Expiry: key.OAuthExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
})
return
}
key.OIDCAccessToken = token.AccessToken
key.OIDCRefreshToken = token.RefreshToken
key.OIDCExpiry = token.Expiry
key.OAuthAccessToken = token.AccessToken
key.OAuthRefreshToken = token.RefreshToken
key.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce reauthentication.
// We extend the ExpiresAt to reduce re-authentication.
apiKeyLifetime := 24 * time.Hour
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
ExpiresAt: key.ExpiresAt,
LastUsed: key.LastUsed,
OIDCAccessToken: key.OIDCAccessToken,
OIDCRefreshToken: key.OIDCRefreshToken,
OIDCExpiry: key.OIDCExpiry,
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
OAuthAccessToken: key.OAuthAccessToken,
OAuthRefreshToken: key.OAuthRefreshToken,
OAuthExpiry: key.OAuthExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{

View File

@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) {
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) {
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCNotExpired", func(t *testing.T) {
t.Run("OAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCRefresh", func(t *testing.T) {
t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
OIDCExpiry: database.Now().AddDate(0, 0, -1),
OAuthExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) {
RefreshToken: "moo",
Expiry: database.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKey(db, &oauth2Config{
tokenSource: &oauth2TokenSource{
token: func() (*oauth2.Token, error) {
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{
Github: &oauth2Config{
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
return token, nil
},
}),
},
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
})
}
type oauth2Config struct {
tokenSource *oauth2TokenSource
tokenSource oauth2TokenSource
}
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
type oauth2TokenSource struct {
token func() (*oauth2.Token, error)
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return ""
}
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
return o.token()
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}
type oauth2TokenSource func() (*oauth2.Token, error)
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
return o()
}

132
coderd/httpmw/oauth2.go Normal file
View File

@ -0,0 +1,132 @@
package httpmw
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/cryptorand"
)
const (
oauth2StateCookieName = "oauth_state"
oauth2RedirectCookieName = "oauth_redirect"
)
type oauth2StateKey struct{}
type OAuth2State struct {
Token *oauth2.Token
Redirect string
}
// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing.
// *oauth2.Config should be used instead of implementing this in production.
type OAuth2Config interface {
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
// OAuth2 returns the state from an oauth request.
func OAuth2(r *http.Request) OAuth2State {
oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State)
if !ok {
panic("developer error: oauth middleware not provided")
}
return oauth
}
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
// URLs, and handling the exchange inbound. Any route that does not have
// a "code" URL parameter will be redirected.
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if config == nil {
httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{
Message: fmt.Sprintf("The oauth2 method requested is not configured!"),
})
return
}
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
// If the code isn't provided, we'll redirect!
state, err := cryptorand.String(32)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate state string: %s", err),
})
return
}
http.SetCookie(rw, &http.Cookie{
Name: oauth2StateCookieName,
Value: state,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
})
// Redirect must always be specified, otherwise
// an old redirect could apply!
http.SetCookie(rw, &http.Cookie{
Name: oauth2RedirectCookieName,
Value: r.URL.Query().Get("redirect"),
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
})
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
return
}
if state == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "state must be provided",
})
return
}
stateCookie, err := r.Cookie(oauth2StateCookieName)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName),
})
return
}
if stateCookie.Value != state {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "state mismatched",
})
return
}
var redirect string
stateRedirect, err := r.Cookie(oauth2RedirectCookieName)
if err == nil {
redirect = stateRedirect.Value
}
oauthToken, err := config.Exchange(r.Context(), code)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("exchange oauth code: %s", err),
})
return
}
ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{
Token: oauthToken,
Redirect: redirect,
})
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,98 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/httpmw"
)
type testOAuth2Provider struct {
}
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "?state=" + url.QueryEscape(state)
}
func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "hello",
}, nil
}
func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
return nil
}
func TestOAuth2(t *testing.T) {
t.Parallel()
t.Run("NotSetup", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
})
t.Run("RedirectWithoutCode", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
}
require.Len(t, res.Result().Cookies(), 2)
cookie := res.Result().Cookies()[1]
require.Equal(t, "/dashboard", cookie.Value)
})
t.Run("NoState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("NoStateCookie", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("MismatchedState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
req.AddCookie(&http.Cookie{
Name: "oauth_state",
Value: "mismatch",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("ExchangeCodeAndState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=test&state=something", nil)
req.AddCookie(&http.Cookie{
Name: "oauth_state",
Value: "something",
})
req.AddCookie(&http.Cookie{
Name: "oauth_redirect",
Value: "/dashboard",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)
})
}

View File

@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

155
coderd/userauth.go Normal file
View File

@ -0,0 +1,155 @@
package coderd
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/google/go-github/v43/github"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
)
// GithubOAuth2Provider exposes required functions for the Github authentication flow.
type GithubOAuth2Config struct {
httpmw.OAuth2Config
AuthenticatedUser func(ctx context.Context, client *http.Client) (*github.User, error)
ListEmails func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error)
ListOrganizationMemberships func(ctx context.Context, client *http.Client) ([]*github.Membership, error)
AllowSignups bool
AllowOrganizations []string
}
func (api *api) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{
Password: true,
Github: api.GithubOAuth2Config != nil,
})
}
func (api *api) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
oauthClient := oauth2.NewClient(r.Context(), oauth2.StaticTokenSource(state.Token))
memberships, err := api.GithubOAuth2Config.ListOrganizationMemberships(r.Context(), oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get authenticated github user organizations: %s", err),
})
return
}
var selectedMembership *github.Membership
for _, membership := range memberships {
for _, allowed := range api.GithubOAuth2Config.AllowOrganizations {
if *membership.Organization.Login != allowed {
continue
}
selectedMembership = membership
break
}
}
if selectedMembership == nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("You aren't a member of the authorized Github organizations!"),
})
return
}
emails, err := api.GithubOAuth2Config.ListEmails(r.Context(), oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get personal github user: %s", err),
})
return
}
var user database.User
// Search for existing users with matching and verified emails.
// If a verified GitHub email matches a Coder user, we will return.
for _, email := range emails {
if email.Verified == nil {
continue
}
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
Email: *email.Email,
})
if errors.Is(err, sql.ErrNoRows) {
continue
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user by email: %s", err),
})
return
}
if !*email.Verified {
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
Message: fmt.Sprintf("Verify the %q email address on Github to authenticate!", *email.Email),
})
return
}
break
}
// If the user doesn't exist, create a new one!
if user.ID == uuid.Nil {
if !api.GithubOAuth2Config.AllowSignups {
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
Message: "Signups are disabled for Github authentication!",
})
return
}
var organizationID uuid.UUID
organizations, _ := api.Database.GetOrganizations(r.Context())
if len(organizations) > 0 {
// Add the user to the first organization. Once multi-organization
// support is added, we should enable a configuration map of user
// email to organization.
organizationID = organizations[0].ID
}
ghUser, err := api.GithubOAuth2Config.AuthenticatedUser(r.Context(), oauthClient)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get authenticated github user: %s", err),
})
return
}
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
Email: *ghUser.Email,
Username: *ghUser.Login,
OrganizationID: organizationID,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("create user: %s", err),
})
return
}
}
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypeGithub,
OAuthAccessToken: state.Token.AccessToken,
OAuthRefreshToken: state.Token.RefreshToken,
OAuthExpiry: state.Token.Expiry,
})
if !created {
return
}
redirect := state.Redirect
if redirect == "" {
redirect = "/"
}
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
}

205
coderd/userauth_test.go Normal file
View File

@ -0,0 +1,205 @@
package coderd_test
import (
"context"
"net/http"
"net/url"
"testing"
"github.com/google/go-github/v43/github"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
)
type oauth2Config struct{}
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "token",
}, nil
}
func (*oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return nil
}
func TestUserAuthMethods(t *testing.T) {
t.Parallel()
t.Run("Password", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
methods, err := client.AuthMethods(context.Background())
require.NoError(t, err)
require.True(t, methods.Password)
require.False(t, methods.Github)
})
t.Run("Github", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{},
})
methods, err := client.AuthMethods(context.Background())
require.NoError(t, err)
require.True(t, methods.Password)
require.True(t, methods.Github)
})
}
func TestUserOAuth2Github(t *testing.T) {
t.Parallel()
t.Run("NotInAllowedOrganization", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("kyle"),
},
}}, nil
},
},
})
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
t.Run("UnverifiedEmail", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{{
Email: github.String("testuser@coder.com"),
Verified: github.Bool(false),
}}, nil
},
},
})
_ = coderdtest.CreateFirstUser(t, client)
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("BlockSignups", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{}, nil
},
},
})
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusForbidden, resp.StatusCode)
})
t.Run("Signup", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
AllowSignups: true,
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{
Login: github.String("kyle"),
Email: github.String("kyle@coder.com"),
}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{}, nil
},
},
})
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
t.Run("Login", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
Organization: &github.Organization{
Login: github.String("coder"),
},
}}, nil
},
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
return &github.User{}, nil
},
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
return []*github.UserEmail{{
Email: github.String("testuser@coder.com"),
Verified: github.Bool(true),
}}, nil
},
},
})
_ = coderdtest.CreateFirstUser(t, client)
resp := oauth2Callback(t, client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
}
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
state := "somestate"
oauthURL, err := client.URL.Parse("/api/v2/users/oauth2/github/callback?code=asd&state=" + state)
require.NoError(t, err)
req, err := http.NewRequest("GET", oauthURL.String(), nil)
require.NoError(t, err)
req.AddCookie(&http.Cookie{
Name: "oauth_state",
Value: state,
})
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
t.Cleanup(func() {
_ = res.Body.Close()
})
return res
}

View File

@ -1,6 +1,7 @@
package coderd
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/json"
@ -71,66 +72,10 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
return
}
hashedPassword, err := userpassword.Hash(createUser.Password)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("hash password: %s", err.Error()),
})
return
}
// Create the user, organization, and membership to the user.
var user database.User
var organization database.Organization
err = api.Database.InTx(func(db database.Store) error {
user, err = api.Database.InsertUser(r.Context(), database.InsertUserParams{
ID: uuid.New(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
Username: createUser.Username,
LoginType: database.LoginTypeBuiltIn,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create user: %w", err)
}
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
if err != nil {
return xerrors.Errorf("generate user gitsshkey: %w", err)
}
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
PrivateKey: privateKey,
PublicKey: publicKey,
})
if err != nil {
return xerrors.Errorf("insert user gitsshkey: %w", err)
}
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.New(),
Name: createUser.OrganizationName,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Roles: []string{"organization-admin"},
})
if err != nil {
return xerrors.Errorf("create organization member: %w", err)
}
return nil
user, organizationID, err := api.createUser(r.Context(), codersdk.CreateUserRequest{
Email: createUser.Email,
Username: createUser.Username,
Password: createUser.Password,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@ -141,7 +86,7 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusCreated, codersdk.CreateFirstUserResponse{
UserID: user.ID,
OrganizationID: organization.ID,
OrganizationID: organizationID,
})
}
@ -262,56 +207,7 @@ func (api *api) postUsers(rw http.ResponseWriter, r *http.Request) {
return
}
hashedPassword, err := userpassword.Hash(createUser.Password)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("hash password: %s", err.Error()),
})
return
}
var user database.User
err = api.Database.InTx(func(db database.Store) error {
user, err = db.InsertUser(r.Context(), database.InsertUserParams{
ID: uuid.New(),
Email: createUser.Email,
HashedPassword: []byte(hashedPassword),
Username: createUser.Username,
LoginType: database.LoginTypeBuiltIn,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create user: %w", err)
}
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
if err != nil {
return xerrors.Errorf("generate user gitsshkey: %w", err)
}
_, err = db.InsertGitSSHKey(r.Context(), database.InsertGitSSHKeyParams{
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
PrivateKey: privateKey,
PublicKey: publicKey,
})
if err != nil {
return xerrors.Errorf("insert user gitsshkey: %w", err)
}
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Roles: []string{},
})
if err != nil {
return xerrors.Errorf("create organization member: %w", err)
}
return nil
})
user, _, err := api.createUser(r.Context(), createUser)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: err.Error(),
@ -542,41 +438,13 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) {
return
}
keyID, keySecret, err := generateAPIKeyIDSecret()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
})
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
if !created {
return
}
hashed := sha256.Sum256([]byte(keySecret))
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: keyID,
UserID: user.ID,
ExpiresAt: database.Now().Add(24 * time.Hour),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: database.LoginTypeBuiltIn,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert api key: %s", err.Error()),
})
return
}
// This format is consumed by the APIKey middleware.
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
http.SetCookie(rw, &http.Cookie{
Name: httpmw.AuthCookie,
Value: sessionToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: api.SecureAuthCookie,
})
httpapi.Write(rw, http.StatusCreated, codersdk.LoginWithPasswordResponse{
SessionToken: sessionToken,
@ -595,35 +463,15 @@ func (api *api) postAPIKey(rw http.ResponseWriter, r *http.Request) {
return
}
keyID, keySecret, err := generateAPIKeyIDSecret()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
})
return
}
hashed := sha256.Sum256([]byte(keySecret))
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: keyID,
UserID: apiKey.UserID,
ExpiresAt: database.Now().AddDate(1, 0, 0), // Expire after 1 year (same as v1)
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: database.LoginTypeBuiltIn,
sessionToken, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
UserID: user.ID,
LoginType: database.LoginTypePassword,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert api key: %s", err.Error()),
})
if !created {
return
}
// This format is consumed by the APIKey middleware.
generatedAPIKey := fmt.Sprintf("%s-%s", keyID, keySecret)
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: generatedAPIKey})
httpapi.Write(rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: sessionToken})
}
// Clear the user's session cookie
@ -984,6 +832,117 @@ func generateAPIKeyIDSecret() (id string, secret string, err error) {
return id, secret, nil
}
func (api *api) createAPIKey(rw http.ResponseWriter, r *http.Request, params database.InsertAPIKeyParams) (string, bool) {
keyID, keySecret, err := generateAPIKeyIDSecret()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate api key parts: %s", err.Error()),
})
return "", false
}
hashed := sha256.Sum256([]byte(keySecret))
_, err = api.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: keyID,
UserID: params.UserID,
ExpiresAt: database.Now().Add(24 * time.Hour),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
HashedSecret: hashed[:],
LoginType: params.LoginType,
OAuthAccessToken: params.OAuthAccessToken,
OAuthRefreshToken: params.OAuthRefreshToken,
OAuthIDToken: params.OAuthIDToken,
OAuthExpiry: params.OAuthExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("insert api key: %s", err.Error()),
})
return "", false
}
// This format is consumed by the APIKey middleware.
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
http.SetCookie(rw, &http.Cookie{
Name: httpmw.AuthCookie,
Value: sessionToken,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: api.SecureAuthCookie,
})
return sessionToken, true
}
func (api *api) createUser(ctx context.Context, req codersdk.CreateUserRequest) (database.User, uuid.UUID, error) {
var user database.User
return user, req.OrganizationID, api.Database.InTx(func(db database.Store) error {
// If no organization is provided, create a new one for the user.
if req.OrganizationID == uuid.Nil {
organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
ID: uuid.New(),
Name: req.Username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
req.OrganizationID = organization.ID
}
params := database.InsertUserParams{
ID: uuid.New(),
Email: req.Email,
Username: req.Username,
LoginType: database.LoginTypePassword,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
}
// If a user signs up with OAuth, they can have no password!
if req.Password != "" {
hashedPassword, err := userpassword.Hash(req.Password)
if err != nil {
return xerrors.Errorf("hash password: %w", err)
}
params.HashedPassword = []byte(hashedPassword)
}
var err error
user, err = db.InsertUser(ctx, params)
if err != nil {
return xerrors.Errorf("create user: %w", err)
}
privateKey, publicKey, err := gitsshkey.Generate(api.SSHKeygenAlgorithm)
if err != nil {
return xerrors.Errorf("generate user gitsshkey: %w", err)
}
_, err = db.InsertGitSSHKey(ctx, database.InsertGitSSHKeyParams{
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
PrivateKey: privateKey,
PublicKey: publicKey,
})
if err != nil {
return xerrors.Errorf("insert user gitsshkey: %w", err)
}
_, err = db.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
OrganizationID: req.OrganizationID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Roles: []string{},
})
if err != nil {
return xerrors.Errorf("create organization member: %w", err)
}
return nil
})
}
func convertUser(user database.User) codersdk.User {
return codersdk.User{
ID: user.ID,

View File

@ -241,13 +241,14 @@ func TestUpdateUserProfile(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
existentUser, _ := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
existentUser, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
Email: "bruno@coder.com",
Username: "bruno",
Password: "password",
OrganizationID: user.OrganizationID,
})
_, err := client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
require.NoError(t, err)
_, err = client.UpdateUserProfile(context.Background(), codersdk.Me, codersdk.UpdateUserProfileRequest{
Username: existentUser.Username,
Email: "newemail@coder.com",
})