feat: Add GitHub OAuth (#1050)

* Initial oauth

* Add Github authentication

* Add AuthMethods endpoint

* Add frontend

* Rename basic authentication to password

* Add flags for configuring GitHub auth

* Remove name from API keys

* Fix authmethods in test

* Add stories and display auth methods error
This commit is contained in:
Kyle Carberry
2022-04-23 17:58:57 -05:00
committed by GitHub
parent 3976994781
commit 7496c3da81
41 changed files with 1251 additions and 422 deletions

View File

@ -20,12 +20,6 @@ import (
// AuthCookie represents the name of the cookie the API key is stored in.
const AuthCookie = "session_token"
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
// It is abstracted for simple testing.
type OAuth2Config interface {
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
type apiKeyContextKey struct{}
// APIKey returns the API key from the ExtractAPIKey handler.
@ -37,10 +31,16 @@ func APIKey(r *http.Request) database.APIKey {
return apiKey
}
// OAuth2Configs is a collection of configurations for OAuth-based authentication.
// This should be extended to support other authentication types in the future.
type OAuth2Configs struct {
Github OAuth2Config
}
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(AuthCookie)
@ -99,14 +99,24 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
// Tracks if the API key has properties updated!
changed := false
if key.LoginType == database.LoginTypeOIDC {
// Check if the OIDC token is expired!
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
if key.LoginType != database.LoginTypePassword {
// Check if the OAuth token is expired!
if key.OAuthExpiry.Before(now) && !key.OAuthExpiry.IsZero() {
var oauthConfig OAuth2Config
switch key.LoginType {
case database.LoginTypeGithub:
oauthConfig = oauth.Github
default:
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("unexpected authentication type %q", key.LoginType),
})
return
}
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OIDCAccessToken,
RefreshToken: key.OIDCRefreshToken,
Expiry: key.OIDCExpiry,
AccessToken: key.OAuthAccessToken,
RefreshToken: key.OAuthRefreshToken,
Expiry: key.OAuthExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
@ -114,9 +124,9 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
})
return
}
key.OIDCAccessToken = token.AccessToken
key.OIDCRefreshToken = token.RefreshToken
key.OIDCExpiry = token.Expiry
key.OAuthAccessToken = token.AccessToken
key.OAuthRefreshToken = token.RefreshToken
key.OAuthExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
@ -136,21 +146,20 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce reauthentication.
// We extend the ExpiresAt to reduce re-authentication.
apiKeyLifetime := 24 * time.Hour
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
ExpiresAt: key.ExpiresAt,
LastUsed: key.LastUsed,
OIDCAccessToken: key.OIDCAccessToken,
OIDCRefreshToken: key.OIDCRefreshToken,
OIDCExpiry: key.OIDCExpiry,
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
OAuthAccessToken: key.OAuthAccessToken,
OAuthRefreshToken: key.OAuthRefreshToken,
OAuthExpiry: key.OAuthExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{

View File

@ -189,7 +189,6 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
@ -207,7 +206,6 @@ func TestAPIKey(t *testing.T) {
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
@ -277,7 +275,7 @@ func TestAPIKey(t *testing.T) {
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCNotExpired", func(t *testing.T) {
t.Run("OAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
@ -294,7 +292,7 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
@ -311,7 +309,7 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCRefresh", func(t *testing.T) {
t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
@ -328,9 +326,9 @@ func TestAPIKey(t *testing.T) {
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LoginType: database.LoginTypeGithub,
LastUsed: database.Now(),
OIDCExpiry: database.Now().AddDate(0, 0, -1),
OAuthExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
@ -338,11 +336,11 @@ func TestAPIKey(t *testing.T) {
RefreshToken: "moo",
Expiry: database.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKey(db, &oauth2Config{
tokenSource: &oauth2TokenSource{
token: func() (*oauth2.Token, error) {
httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{
Github: &oauth2Config{
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
return token, nil
},
}),
},
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
@ -354,22 +352,28 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
require.Equal(t, token.AccessToken, gotAPIKey.OAuthAccessToken)
})
}
type oauth2Config struct {
tokenSource *oauth2TokenSource
tokenSource oauth2TokenSource
}
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
type oauth2TokenSource struct {
token func() (*oauth2.Token, error)
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return ""
}
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
return o.token()
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}
type oauth2TokenSource func() (*oauth2.Token, error)
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
return o()
}

132
coderd/httpmw/oauth2.go Normal file
View File

@ -0,0 +1,132 @@
package httpmw
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/cryptorand"
)
const (
oauth2StateCookieName = "oauth_state"
oauth2RedirectCookieName = "oauth_redirect"
)
type oauth2StateKey struct{}
type OAuth2State struct {
Token *oauth2.Token
Redirect string
}
// OAuth2Config exposes a subset of *oauth2.Config functions for easier testing.
// *oauth2.Config should be used instead of implementing this in production.
type OAuth2Config interface {
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
// OAuth2 returns the state from an oauth request.
func OAuth2(r *http.Request) OAuth2State {
oauth, ok := r.Context().Value(oauth2StateKey{}).(OAuth2State)
if !ok {
panic("developer error: oauth middleware not provided")
}
return oauth
}
// ExtractOAuth2 is a middleware for automatically redirecting to OAuth
// URLs, and handling the exchange inbound. Any route that does not have
// a "code" URL parameter will be redirected.
func ExtractOAuth2(config OAuth2Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if config == nil {
httpapi.Write(rw, http.StatusPreconditionRequired, httpapi.Response{
Message: fmt.Sprintf("The oauth2 method requested is not configured!"),
})
return
}
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" {
// If the code isn't provided, we'll redirect!
state, err := cryptorand.String(32)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("generate state string: %s", err),
})
return
}
http.SetCookie(rw, &http.Cookie{
Name: oauth2StateCookieName,
Value: state,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
})
// Redirect must always be specified, otherwise
// an old redirect could apply!
http.SetCookie(rw, &http.Cookie{
Name: oauth2RedirectCookieName,
Value: r.URL.Query().Get("redirect"),
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
})
http.Redirect(rw, r, config.AuthCodeURL(state, oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
return
}
if state == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "state must be provided",
})
return
}
stateCookie, err := r.Cookie(oauth2StateCookieName)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("%q cookie must be provided", oauth2StateCookieName),
})
return
}
if stateCookie.Value != state {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "state mismatched",
})
return
}
var redirect string
stateRedirect, err := r.Cookie(oauth2RedirectCookieName)
if err == nil {
redirect = stateRedirect.Value
}
oauthToken, err := config.Exchange(r.Context(), code)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("exchange oauth code: %s", err),
})
return
}
ctx := context.WithValue(r.Context(), oauth2StateKey{}, OAuth2State{
Token: oauthToken,
Redirect: redirect,
})
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,98 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/httpmw"
)
type testOAuth2Provider struct {
}
func (*testOAuth2Provider) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "?state=" + url.QueryEscape(state)
}
func (*testOAuth2Provider) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "hello",
}, nil
}
func (*testOAuth2Provider) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
return nil
}
func TestOAuth2(t *testing.T) {
t.Parallel()
t.Run("NotSetup", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(nil)(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusPreconditionRequired, res.Result().StatusCode)
})
t.Run("RedirectWithoutCode", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?redirect="+url.QueryEscape("/dashboard"), nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
location := res.Header().Get("Location")
if !assert.NotEmpty(t, location) {
return
}
require.Len(t, res.Result().Cookies(), 2)
cookie := res.Result().Cookies()[1]
require.Equal(t, "/dashboard", cookie.Value)
})
t.Run("NoState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusBadRequest, res.Result().StatusCode)
})
t.Run("NoStateCookie", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("MismatchedState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=something&state=test", nil)
req.AddCookie(&http.Cookie{
Name: "oauth_state",
Value: "mismatch",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(nil).ServeHTTP(res, req)
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
})
t.Run("ExchangeCodeAndState", func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest("GET", "/?code=test&state=something", nil)
req.AddCookie(&http.Cookie{
Name: "oauth_state",
Value: "something",
})
req.AddCookie(&http.Cookie{
Name: "oauth_redirect",
Value: "/dashboard",
})
res := httptest.NewRecorder()
httpmw.ExtractOAuth2(&testOAuth2Provider{})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
state := httpmw.OAuth2(r)
require.Equal(t, "/dashboard", state.Redirect)
})).ServeHTTP(res, req)
})
}

View File

@ -41,7 +41,7 @@ func TestOrganizationParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestTemplateParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestTemplateVersionParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceBuildParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),

View File

@ -40,7 +40,7 @@ func TestWorkspaceParam(t *testing.T) {
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
LoginType: database.LoginTypePassword,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),