mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
feat: Add OIDC authentication (#3314)
* feat: Add OIDC authentication * Extract username into a separate package and add OIDC tests * Add test case for invalid tokens * Add test case for username as email * Add OIDC to the frontend * Improve comments from self-review * Add authentication docs * Add telemetry * Update docs/install/auth.md Co-authored-by: Ammar Bandukwala <ammar@ammar.io> * Update docs/install/auth.md Co-authored-by: Ammar Bandukwala <ammar@ammar.io> * Remove username package Co-authored-by: Ammar Bandukwala <ammar@ammar.io>
This commit is contained in:
@ -57,6 +57,7 @@ type Options struct {
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
GithubOAuth2Config *GithubOAuth2Config
|
||||
OIDCConfig *OIDCConfig
|
||||
ICEServers []webrtc.ICEServer
|
||||
SecureAuthCookie bool
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
@ -105,6 +106,7 @@ func New(options *Options) *API {
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
|
||||
oauthConfigs := &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
OIDC: options.OIDCConfig,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
|
||||
|
||||
@ -259,6 +261,10 @@ func New(options *Options) *API {
|
||||
r.Get("/callback", api.userOAuth2Github)
|
||||
})
|
||||
})
|
||||
r.Route("/oidc/callback", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOAuth2(options.OIDCConfig))
|
||||
r.Get("/", api.userOIDC)
|
||||
})
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
|
@ -248,6 +248,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
|
||||
// Has it's own auth
|
||||
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
|
||||
"GET:/api/v2/users/oidc/callback": {NoAuthorize: true},
|
||||
|
||||
// All workspaceagents endpoints do not use rbac
|
||||
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
|
||||
|
@ -63,6 +63,7 @@ type Options struct {
|
||||
Authorizer rbac.Authorizer
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
OIDCConfig *coderd.OIDCConfig
|
||||
GoogleTokenValidator *idtoken.Validator
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
APIRateLimit int
|
||||
@ -189,6 +190,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer)
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
OIDCConfig: options.OIDCConfig,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
TURNServer: turnServer,
|
||||
|
3
coderd/database/dump.sql
generated
3
coderd/database/dump.sql
generated
@ -27,7 +27,8 @@ CREATE TYPE log_source AS ENUM (
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'password',
|
||||
'github'
|
||||
'github',
|
||||
'oidc'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_destination_scheme AS ENUM (
|
||||
|
@ -31,6 +31,11 @@ func main() {
|
||||
}
|
||||
|
||||
cmd := exec.Command(
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"--network=host",
|
||||
"postgres:13",
|
||||
"pg_dump",
|
||||
"--schema-only",
|
||||
connection,
|
||||
|
7
coderd/database/migrations/000032_oidc.down.sql
Normal file
7
coderd/database/migrations/000032_oidc.down.sql
Normal file
@ -0,0 +1,7 @@
|
||||
CREATE TYPE old_login_type AS ENUM (
|
||||
'password',
|
||||
'github'
|
||||
);
|
||||
ALTER TABLE api_keys ALTER COLUMN login_type TYPE old_login_type USING (login_type::text::old_login_type);
|
||||
DROP TYPE login_type;
|
||||
ALTER TYPE old_login_type RENAME TO login_type;
|
8
coderd/database/migrations/000032_oidc.up.sql
Normal file
8
coderd/database/migrations/000032_oidc.up.sql
Normal file
@ -0,0 +1,8 @@
|
||||
CREATE TYPE new_login_type AS ENUM (
|
||||
'password',
|
||||
'github',
|
||||
'oidc'
|
||||
);
|
||||
ALTER TABLE api_keys ALTER COLUMN login_type TYPE new_login_type USING (login_type::text::new_login_type);
|
||||
DROP TYPE login_type;
|
||||
ALTER TYPE new_login_type RENAME TO login_type;
|
@ -101,6 +101,7 @@ type LoginType string
|
||||
const (
|
||||
LoginTypePassword LoginType = "password"
|
||||
LoginTypeGithub LoginType = "github"
|
||||
LoginTypeOIDC LoginType = "oidc"
|
||||
)
|
||||
|
||||
func (e *LoginType) Scan(src interface{}) error {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
@ -16,8 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
validate *validator.Validate
|
||||
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$")
|
||||
validate *validator.Validate
|
||||
)
|
||||
|
||||
// This init is used to create a validator and register validation-specific
|
||||
@ -39,13 +37,7 @@ func init() {
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if len(str) > 32 {
|
||||
return false
|
||||
}
|
||||
if len(str) < 1 {
|
||||
return false
|
||||
}
|
||||
return usernameRegex.MatchString(str)
|
||||
return UsernameValid(str)
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -81,71 +81,6 @@ func TestRead(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUsername(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Tests whether usernames are valid or not.
|
||||
testCases := []struct {
|
||||
Username string
|
||||
Valid bool
|
||||
}{
|
||||
{"1", true},
|
||||
{"12", true},
|
||||
{"123", true},
|
||||
{"12345678901234567890", true},
|
||||
{"123456789012345678901", true},
|
||||
{"a", true},
|
||||
{"a1", true},
|
||||
{"a1b2", true},
|
||||
{"a1b2c3d4e5f6g7h8i9j0", true},
|
||||
{"a1b2c3d4e5f6g7h8i9j0k", true},
|
||||
{"aa", true},
|
||||
{"abc", true},
|
||||
{"abcdefghijklmnopqrst", true},
|
||||
{"abcdefghijklmnopqrstu", true},
|
||||
{"wow-test", true},
|
||||
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{" a", false},
|
||||
{" a ", false},
|
||||
{" 1", false},
|
||||
{"1 ", false},
|
||||
{" aa", false},
|
||||
{"aa ", false},
|
||||
{" 12", false},
|
||||
{"12 ", false},
|
||||
{" a1", false},
|
||||
{"a1 ", false},
|
||||
{" abcdefghijklmnopqrstu", false},
|
||||
{"abcdefghijklmnopqrstu ", false},
|
||||
{" 123456789012345678901", false},
|
||||
{" a1b2c3d4e5f6g7h8i9j0k", false},
|
||||
{"a1b2c3d4e5f6g7h8i9j0k ", false},
|
||||
{"bananas_wow", false},
|
||||
{"test--now", false},
|
||||
|
||||
{"123456789012345678901234567890123", false},
|
||||
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
|
||||
{"123456789012345678901234567890123123456789012345678901234567890123", false},
|
||||
}
|
||||
type toValidate struct {
|
||||
Username string `json:"username" validate:"username"`
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.Username, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
data, err := json.Marshal(toValidate{testCase.Username})
|
||||
require.NoError(t, err)
|
||||
r := httptest.NewRequest("POST", "/", bytes.NewBuffer(data))
|
||||
|
||||
var validate toValidate
|
||||
require.Equal(t, testCase.Valid, httpapi.Read(rw, r, &validate))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func WebsocketCloseMsg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
45
coderd/httpapi/username.go
Normal file
45
coderd/httpapi/username.go
Normal file
@ -0,0 +1,45 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
)
|
||||
|
||||
var (
|
||||
usernameValid = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$")
|
||||
usernameReplace = regexp.MustCompile("[^a-zA-Z0-9-]*")
|
||||
)
|
||||
|
||||
// UsernameValid returns whether the input string is a valid username.
|
||||
func UsernameValid(str string) bool {
|
||||
if len(str) > 32 {
|
||||
return false
|
||||
}
|
||||
if len(str) < 1 {
|
||||
return false
|
||||
}
|
||||
return usernameValid.MatchString(str)
|
||||
}
|
||||
|
||||
// UsernameFrom returns a best-effort username from the provided string.
|
||||
//
|
||||
// It first attempts to validate the incoming string, which will
|
||||
// be returned if it is valid. It then will attempt to extract
|
||||
// the username from an email address. If no success happens during
|
||||
// these steps, a random username will be returned.
|
||||
func UsernameFrom(str string) string {
|
||||
if UsernameValid(str) {
|
||||
return str
|
||||
}
|
||||
emailAt := strings.LastIndex(str, "@")
|
||||
if emailAt >= 0 {
|
||||
str = str[:emailAt]
|
||||
}
|
||||
str = usernameReplace.ReplaceAllString(str, "")
|
||||
if UsernameValid(str) {
|
||||
return str
|
||||
}
|
||||
return strings.ReplaceAll(namesgenerator.GetRandomName(1), "_", "-")
|
||||
}
|
102
coderd/httpapi/username_test.go
Normal file
102
coderd/httpapi/username_test.go
Normal file
@ -0,0 +1,102 @@
|
||||
package httpapi_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
func TestValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Tests whether usernames are valid or not.
|
||||
testCases := []struct {
|
||||
Username string
|
||||
Valid bool
|
||||
}{
|
||||
{"1", true},
|
||||
{"12", true},
|
||||
{"123", true},
|
||||
{"12345678901234567890", true},
|
||||
{"123456789012345678901", true},
|
||||
{"a", true},
|
||||
{"a1", true},
|
||||
{"a1b2", true},
|
||||
{"a1b2c3d4e5f6g7h8i9j0", true},
|
||||
{"a1b2c3d4e5f6g7h8i9j0k", true},
|
||||
{"aa", true},
|
||||
{"abc", true},
|
||||
{"abcdefghijklmnopqrst", true},
|
||||
{"abcdefghijklmnopqrstu", true},
|
||||
{"wow-test", true},
|
||||
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{" a", false},
|
||||
{" a ", false},
|
||||
{" 1", false},
|
||||
{"1 ", false},
|
||||
{" aa", false},
|
||||
{"aa ", false},
|
||||
{" 12", false},
|
||||
{"12 ", false},
|
||||
{" a1", false},
|
||||
{"a1 ", false},
|
||||
{" abcdefghijklmnopqrstu", false},
|
||||
{"abcdefghijklmnopqrstu ", false},
|
||||
{" 123456789012345678901", false},
|
||||
{" a1b2c3d4e5f6g7h8i9j0k", false},
|
||||
{"a1b2c3d4e5f6g7h8i9j0k ", false},
|
||||
{"bananas_wow", false},
|
||||
{"test--now", false},
|
||||
|
||||
{"123456789012345678901234567890123", false},
|
||||
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
|
||||
{"123456789012345678901234567890123123456789012345678901234567890123", false},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.Username, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, testCase.Valid, httpapi.UsernameValid(testCase.Username))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
From string
|
||||
Match string
|
||||
}{
|
||||
{"1", "1"},
|
||||
{"kyle@kwc.io", "kyle"},
|
||||
{"kyle+wow@kwc.io", "kylewow"},
|
||||
{"kyle+testing", "kyletesting"},
|
||||
{"kyle-testing", "kyle-testing"},
|
||||
{"much.”more unusual”@example.com", "muchmoreunusual"},
|
||||
|
||||
// Cases where an invalid string is provided, and the result is a random name.
|
||||
{"123456789012345678901234567890123", ""},
|
||||
{"very.unusual.”@”.unusual.com@example.com", ""},
|
||||
{"___@ok.com", ""},
|
||||
{" something with spaces ", ""},
|
||||
{"--test--", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.From, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
converted := httpapi.UsernameFrom(testCase.From)
|
||||
t.Log(converted)
|
||||
require.True(t, httpapi.UsernameValid(converted))
|
||||
if testCase.Match == "" {
|
||||
require.NotEqual(t, testCase.From, converted)
|
||||
} else {
|
||||
require.Equal(t, testCase.Match, converted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -49,6 +49,7 @@ func AuthorizationUserRoles(r *http.Request) database.GetAuthorizationUserRolesR
|
||||
// This should be extended to support other authentication types in the future.
|
||||
type OAuth2Configs struct {
|
||||
Github OAuth2Config
|
||||
OIDC OAuth2Config
|
||||
}
|
||||
|
||||
const (
|
||||
@ -155,6 +156,8 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool
|
||||
switch key.LoginType {
|
||||
case database.LoginTypeGithub:
|
||||
oauthConfig = oauth.Github
|
||||
case database.LoginTypeOIDC:
|
||||
oauthConfig = oauth.OIDC
|
||||
default:
|
||||
write(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
|
@ -41,6 +41,8 @@ type Options struct {
|
||||
BuiltinPostgres bool
|
||||
DeploymentID string
|
||||
GitHubOAuth bool
|
||||
OIDCAuth bool
|
||||
OIDCIssuerURL string
|
||||
Prometheus bool
|
||||
STUN bool
|
||||
SnapshotFrequency time.Duration
|
||||
@ -229,6 +231,8 @@ func (r *remoteReporter) deployment() error {
|
||||
BuiltinPostgres: r.options.BuiltinPostgres,
|
||||
Containerized: containerized,
|
||||
GitHubOAuth: r.options.GitHubOAuth,
|
||||
OIDCAuth: r.options.OIDCAuth,
|
||||
OIDCIssuerURL: r.options.OIDCIssuerURL,
|
||||
Prometheus: r.options.Prometheus,
|
||||
STUN: r.options.STUN,
|
||||
Tunnel: r.options.Tunnel,
|
||||
@ -601,6 +605,8 @@ type Deployment struct {
|
||||
Containerized bool `json:"containerized"`
|
||||
Tunnel bool `json:"tunnel"`
|
||||
GitHubOAuth bool `json:"github_oauth"`
|
||||
OIDCAuth bool `json:"oidc_auth"`
|
||||
OIDCIssuerURL string `json:"oidc_issuer_url"`
|
||||
Prometheus bool `json:"prometheus"`
|
||||
STUN bool `json:"stun"`
|
||||
OSType string `json:"os_type"`
|
||||
|
@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
@ -40,6 +42,7 @@ func (api *API) userAuthMethods(rw http.ResponseWriter, _ *http.Request) {
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.AuthMethods{
|
||||
Password: true,
|
||||
Github: api.GithubOAuth2Config != nil,
|
||||
OIDC: api.OIDCConfig != nil,
|
||||
})
|
||||
}
|
||||
|
||||
@ -205,3 +208,137 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
httpmw.OAuth2Config
|
||||
|
||||
Verifier *oidc.IDTokenVerifier
|
||||
// EmailDomain is the domain to enforce when a user authenticates.
|
||||
EmailDomain string
|
||||
AllowSignups bool
|
||||
}
|
||||
|
||||
func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
||||
state := httpmw.OAuth2(r)
|
||||
|
||||
// See the example here: https://github.com/coreos/go-oidc
|
||||
rawIDToken, ok := state.Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "id_token not found in response payload. Ensure your OIDC callback is configured correctly!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := api.OIDCConfig.Verifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to verify OIDC token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
Username string `json:"preferred_username"`
|
||||
}
|
||||
err = idToken.Claims(&claims)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to extract OIDC claims.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if claims.Email == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No email found in OIDC payload!",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !claims.Verified {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Verify the %q email address on your OIDC provider to authenticate!", claims.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
// The username is a required property in Coder. We make a best-effort
|
||||
// attempt at using what the claims provide, but if that fails we will
|
||||
// generate a random username.
|
||||
if !httpapi.UsernameValid(claims.Username) {
|
||||
// If no username is provided, we can default to use the email address.
|
||||
// This will be converted in the from function below, so it's safe
|
||||
// to keep the domain.
|
||||
if claims.Username == "" {
|
||||
claims.Username = claims.Email
|
||||
}
|
||||
claims.Username = httpapi.UsernameFrom(claims.Username)
|
||||
}
|
||||
if api.OIDCConfig.EmailDomain != "" {
|
||||
if !strings.HasSuffix(claims.Email, api.OIDCConfig.EmailDomain) {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("Your email %q is not a part of the %q domain!", claims.Email, api.OIDCConfig.EmailDomain),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var user database.User
|
||||
user, err = api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
Email: claims.Email,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if !api.OIDCConfig.AllowSignups {
|
||||
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Signups are disabled for OIDC authentication!",
|
||||
})
|
||||
return
|
||||
}
|
||||
var organizationID uuid.UUID
|
||||
organizations, _ := api.Database.GetOrganizations(r.Context())
|
||||
if len(organizations) > 0 {
|
||||
// Add the user to the first organization. Once multi-organization
|
||||
// support is added, we should enable a configuration map of user
|
||||
// email to organization.
|
||||
organizationID = organizations[0].ID
|
||||
}
|
||||
user, _, err = api.createUser(r.Context(), codersdk.CreateUserRequest{
|
||||
Email: claims.Email,
|
||||
Username: claims.Username,
|
||||
OrganizationID: organizationID,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating user.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get user by email.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, created := api.createAPIKey(rw, r, database.InsertAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthAccessToken: state.Token.AccessToken,
|
||||
OAuthRefreshToken: state.Token.RefreshToken,
|
||||
OAuthExpiry: state.Token.Expiry,
|
||||
})
|
||||
if !created {
|
||||
return
|
||||
}
|
||||
|
||||
redirect := state.Redirect
|
||||
if redirect == "" {
|
||||
redirect = "/"
|
||||
}
|
||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
@ -2,11 +2,19 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
@ -16,13 +24,18 @@ import (
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type oauth2Config struct{}
|
||||
type oauth2Config struct {
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
func (o *oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
if o.token != nil {
|
||||
return o.token, nil
|
||||
}
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
@ -249,6 +262,169 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserOIDC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
Name string
|
||||
Claims jwt.MapClaims
|
||||
AllowSignups bool
|
||||
EmailDomain string
|
||||
Username string
|
||||
StatusCode int
|
||||
}{{
|
||||
Name: "EmailNotVerified",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
},
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusForbidden,
|
||||
}, {
|
||||
Name: "NotInRequiredEmailDomain",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
"email_verified": true,
|
||||
},
|
||||
AllowSignups: true,
|
||||
EmailDomain: "coder.com",
|
||||
StatusCode: http.StatusForbidden,
|
||||
}, {
|
||||
Name: "EmptyClaims",
|
||||
Claims: jwt.MapClaims{},
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}, {
|
||||
Name: "NoSignups",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
"email_verified": true,
|
||||
},
|
||||
StatusCode: http.StatusForbidden,
|
||||
}, {
|
||||
Name: "UsernameFromEmail",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
"email_verified": true,
|
||||
},
|
||||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
}, {
|
||||
Name: "UsernameFromClaims",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
"email_verified": true,
|
||||
"preferred_username": "hotdog",
|
||||
},
|
||||
Username: "hotdog",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
}, {
|
||||
// Services like Okta return the email as the username:
|
||||
// https://developer.okta.com/docs/reference/api/oidc/#base-claims-always-present
|
||||
Name: "UsernameAsEmail",
|
||||
Claims: jwt.MapClaims{
|
||||
"email": "kyle@kwc.io",
|
||||
"email_verified": true,
|
||||
"preferred_username": "kyle@kwc.io",
|
||||
},
|
||||
Username: "kyle",
|
||||
AllowSignups: true,
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
}} {
|
||||
tc := tc
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := createOIDCConfig(t, tc.Claims)
|
||||
config.AllowSignups = tc.AllowSignups
|
||||
config.EmailDomain = tc.EmailDomain
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: config,
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
assert.Equal(t, tc.StatusCode, resp.StatusCode)
|
||||
|
||||
if tc.Username != "" {
|
||||
client.SessionToken = resp.Cookies()[0].Value
|
||||
user, err := client.User(context.Background(), "me")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.Username, user.Username)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Disabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
resp := oidcCallback(t, client)
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NoIDToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
},
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("BadVerify", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{},
|
||||
}, &oidc.Config{})
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{
|
||||
token: (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": "invalid",
|
||||
}),
|
||||
},
|
||||
Verifier: verifier,
|
||||
},
|
||||
})
|
||||
resp := oidcCallback(t, client)
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
// createOIDCConfig generates a new OIDCConfig that returns a static token
|
||||
// with the claims provided.
|
||||
func createOIDCConfig(t *testing.T, claims jwt.MapClaims) *coderd.OIDCConfig {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
|
||||
claims["exp"] = time.Now().Add(time.Hour).UnixMilli()
|
||||
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifier := oidc.NewVerifier("", &oidc.StaticKeySet{
|
||||
PublicKeys: []crypto.PublicKey{key.Public()},
|
||||
}, &oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
|
||||
return &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{
|
||||
token: (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": signed,
|
||||
}),
|
||||
},
|
||||
Verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
@ -269,3 +445,26 @@ func oauth2Callback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
})
|
||||
return res
|
||||
}
|
||||
|
||||
func oidcCallback(t *testing.T, client *codersdk.Client) *http.Response {
|
||||
t.Helper()
|
||||
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
state := "somestate"
|
||||
oauthURL, err := client.URL.Parse("/api/v2/users/oidc/callback?code=asd&state=" + state)
|
||||
require.NoError(t, err)
|
||||
req, err := http.NewRequest("GET", oauthURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Value: state,
|
||||
})
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
data, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
t.Log(string(data))
|
||||
return res
|
||||
}
|
||||
|
Reference in New Issue
Block a user