feat: pass access_token to coder_git_auth resource (#6713)

This allows template authors to leverage git auth to perform
custom actions, like clone repositories.
This commit is contained in:
Kyle Carberry
2023-03-22 14:37:08 -05:00
committed by GitHub
parent 79ae7cd639
commit df31636e72
20 changed files with 647 additions and 479 deletions

View File

@ -831,10 +831,6 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
mux := drpcmux.New()
gitAuthProviders := make([]string, 0, len(api.GitAuthConfigs))
for _, cfg := range api.GitAuthConfigs {
gitAuthProviders = append(gitAuthProviders, cfg.ID)
}
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
@ -842,7 +838,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
GitAuthProviders: gitAuthProviders,
GitAuthConfigs: api.GitAuthConfigs,
Telemetry: api.Telemetry,
Tags: tags,
QuotaCommitter: &api.QuotaCommitter,

View File

@ -1,13 +1,17 @@
package gitauth
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
@ -34,6 +38,77 @@ type Config struct {
ValidateURL string
}
// RefreshToken automatically refreshes the token if expired and permitted.
// It returns the token and a bool indicating if the token was refreshed.
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
// If the token is expired and refresh is disabled, we prompt
// the user to authenticate again.
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
return gitAuthLink, false, nil
}
token, err := c.TokenSource(ctx, &oauth2.Token{
AccessToken: gitAuthLink.OAuthAccessToken,
RefreshToken: gitAuthLink.OAuthRefreshToken,
Expiry: gitAuthLink.OAuthExpiry,
}).Token()
if err != nil {
// Even if the token fails to be obtained, we still return false because
// we aren't trying to surface an error, we're just trying to obtain a valid token.
return gitAuthLink, false, nil
}
if c.ValidateURL != "" {
valid, err := c.ValidateToken(ctx, token.AccessToken)
if err != nil {
return gitAuthLink, false, xerrors.Errorf("validate git auth token: %w", err)
}
if !valid {
// The token is no longer valid!
return gitAuthLink, false, nil
}
}
if token.AccessToken != gitAuthLink.OAuthAccessToken {
// Update it
gitAuthLink, err = db.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
ProviderID: c.ID,
UserID: gitAuthLink.UserID,
UpdatedAt: database.Now(),
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
})
if err != nil {
return gitAuthLink, false, xerrors.Errorf("update git auth link: %w", err)
}
}
return gitAuthLink, true, nil
}
// ValidateToken ensures the Git token provided is valid!
func (c *Config) ValidateToken(ctx context.Context, token string) (bool, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.ValidateURL, nil)
if err != nil {
return false, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusUnauthorized {
// The token is no longer valid!
return false, nil
}
if res.StatusCode != http.StatusOK {
data, _ := io.ReadAll(res.Body)
return false, xerrors.Errorf("status %d: body: %s", res.StatusCode, data)
}
return true, nil
}
// ConvertConfig converts the SDK configuration entry format
// to the parsed and ready-to-consume in coderd provider type.
func ConvertConfig(entries []codersdk.GitAuthConfig, accessURL *url.URL) ([]*Config, error) {

View File

@ -1,15 +1,122 @@
package gitauth_test
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
func TestRefreshToken(t *testing.T) {
t.Parallel()
t.Run("FalseIfNoRefresh", func(t *testing.T) {
t.Parallel()
config := &gitauth.Config{
NoRefresh: true,
}
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
OAuthExpiry: time.Time{},
})
require.NoError(t, err)
require.False(t, refreshed)
})
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
t.Parallel()
config := &gitauth.Config{
OAuth2Config: &testutil.OAuth2Config{
TokenSourceFunc: func() (*oauth2.Token, error) {
return nil, xerrors.New("failure")
},
},
}
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
require.NoError(t, err)
require.False(t, refreshed)
})
t.Run("ValidateServerError", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Failure"))
}))
config := &gitauth.Config{
OAuth2Config: &testutil.OAuth2Config{},
ValidateURL: srv.URL,
}
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
require.ErrorContains(t, err, "Failure")
})
t.Run("ValidateFailure", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("Not permitted"))
}))
config := &gitauth.Config{
OAuth2Config: &testutil.OAuth2Config{},
ValidateURL: srv.URL,
}
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
require.NoError(t, err)
require.False(t, refreshed)
})
t.Run("ValidateNoUpdate", func(t *testing.T) {
t.Parallel()
validated := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
close(validated)
}))
accessToken := "testing"
config := &gitauth.Config{
OAuth2Config: &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: accessToken,
},
},
ValidateURL: srv.URL,
}
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
OAuthAccessToken: accessToken,
})
require.NoError(t, err)
require.True(t, valid)
<-validated
})
t.Run("Updates", func(t *testing.T) {
t.Parallel()
config := &gitauth.Config{
ID: "test",
OAuth2Config: &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: "updated",
},
},
}
db := dbfake.New()
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
ProviderID: config.ID,
OAuthAccessToken: "initial",
})
_, valid, err := config.RefreshToken(context.Background(), db, link)
require.NoError(t, err)
require.True(t, valid)
})
}
func TestConvertYAML(t *testing.T) {
t.Parallel()
for _, tc := range []struct {

View File

@ -22,6 +22,7 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/testutil"
)
func randomAPIKeyParts() (id string, secret string) {
@ -462,10 +463,8 @@ func TestAPIKey(t *testing.T) {
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
OAuth2Configs: &httpmw.OAuth2Configs{
Github: &oauth2Config{
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
return oauthToken, nil
}),
Github: &testutil.OAuth2Config{
Token: oauthToken,
},
},
RedirectToLogin: false,
@ -597,25 +596,3 @@ func TestAPIKey(t *testing.T) {
require.Equal(t, sentAPIKey.LoginType, gotAPIKey.LoginType)
})
}
type oauth2Config struct {
tokenSource oauth2TokenSource
}
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return ""
}
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}
type oauth2TokenSource func() (*oauth2.Token, error)
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
return o()
}

View File

@ -28,11 +28,11 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner"
"github.com/coder/coder/provisionerd/proto"
@ -50,7 +50,7 @@ type Server struct {
ID uuid.UUID
Logger slog.Logger
Provisioners []database.ProvisionerType
GitAuthProviders []string
GitAuthConfigs []*gitauth.Config
Tags json.RawMessage
Database database.Store
Pubsub database.Pubsub
@ -210,6 +210,48 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
}
gitAuthProviders := []*sdkproto.GitAuthProvider{}
for _, p := range templateVersion.GitAuthProviders {
link, err := server.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
ProviderID: p,
UserID: owner.ID,
})
if errors.Is(err, sql.ErrNoRows) {
continue
}
if err != nil {
return nil, failJob(fmt.Sprintf("acquire git auth link: %s", err))
}
var config *gitauth.Config
for _, c := range server.GitAuthConfigs {
if c.ID != p {
continue
}
config = c
break
}
// We weren't able to find a matching config for the ID!
if config == nil {
server.Logger.Warn(ctx, "workspace build job is missing git provider",
slog.F("git_provider_id", p),
slog.F("template_version_id", templateVersion.ID),
slog.F("workspace_id", workspaceBuild.WorkspaceID))
continue
}
link, valid, err := config.RefreshToken(ctx, server.Database, link)
if err != nil {
return nil, failJob(fmt.Sprintf("refresh git auth link %q: %s", p, err))
}
if !valid {
continue
}
gitAuthProviders = append(gitAuthProviders, &sdkproto.GitAuthProvider{
Id: p,
AccessToken: link.OAuthAccessToken,
})
}
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: workspaceBuild.ID.String(),
@ -218,6 +260,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
ParameterValues: protoParameters,
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
VariableValues: asVariableValues(templateVariables),
GitAuthProviders: gitAuthProviders,
Metadata: &sdkproto.Provision_Metadata{
CoderUrl: server.AccessURL.String(),
WorkspaceTransition: transition,
@ -857,7 +900,14 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
var completedError sql.NullString
for _, gitAuthProvider := range jobType.TemplateImport.GitAuthProviders {
if !slice.Contains(server.GitAuthProviders, gitAuthProvider) {
contains := false
for _, configuredProvider := range server.GitAuthConfigs {
if configuredProvider.ID == gitAuthProvider {
contains = true
break
}
}
if !contains {
completedError = sql.NullString{
String: fmt.Sprintf("git auth provider %q is not configured", gitAuthProvider),
Valid: true,

View File

@ -12,6 +12,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/testutil"
)
func TestObtainOIDCAccessToken(t *testing.T) {
@ -45,11 +46,9 @@ func TestObtainOIDCAccessToken(t *testing.T) {
LoginType: database.LoginTypeOIDC,
OAuthExpiry: database.Now().Add(-time.Hour),
})
_, err := obtainOIDCAccessToken(ctx, db, &oauth2Config{
tokenSource: func() (*oauth2.Token, error) {
return &oauth2.Token{
AccessToken: "token",
}, nil
_, err := obtainOIDCAccessToken(ctx, db, &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: "token",
},
}, user.ID)
require.NoError(t, err)
@ -61,25 +60,3 @@ func TestObtainOIDCAccessToken(t *testing.T) {
require.Equal(t, "token", link.OAuthAccessToken)
})
}
type oauth2Config struct {
tokenSource oauth2TokenSource
}
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
return ""
}
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}
type oauth2TokenSource func() (*oauth2.Token, error)
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
return o()
}

View File

@ -18,6 +18,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
@ -98,6 +99,11 @@ func TestAcquireJob(t *testing.T) {
t.Run("WorkspaceBuildJob", func(t *testing.T) {
t.Parallel()
srv := setup(t, false)
gitAuthProvider := "github"
srv.GitAuthConfigs = []*gitauth.Config{{
ID: gitAuthProvider,
OAuth2Config: &testutil.OAuth2Config{},
}}
ctx := context.Background()
user := dbgen.User(t, srv.Database, database.User{})
@ -107,6 +113,10 @@ func TestAcquireJob(t *testing.T) {
OAuthExpiry: database.Now().Add(time.Hour),
OAuthAccessToken: "access-token",
})
dbgen.GitAuthLink(t, srv.Database, database.GitAuthLink{
ProviderID: gitAuthProvider,
UserID: user.ID,
})
template := dbgen.Template(t, srv.Database, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
@ -120,6 +130,12 @@ func TestAcquireJob(t *testing.T) {
},
JobID: uuid.New(),
})
err := srv.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
JobID: version.JobID,
GitAuthProviders: []string{gitAuthProvider},
UpdatedAt: database.Now(),
})
require.NoError(t, err)
// Import version job
_ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
ID: version.JobID,
@ -214,6 +230,10 @@ func TestAcquireJob(t *testing.T) {
Value: "second_value",
},
},
GitAuthProviders: []*sdkproto.GitAuthProvider{{
Id: gitAuthProvider,
AccessToken: "access_token",
}},
Metadata: &sdkproto.Provision_Metadata{
CoderUrl: srv.AccessURL.String(),
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
@ -795,7 +815,9 @@ func TestCompleteJob(t *testing.T) {
job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID)
require.NoError(t, err)
require.Contains(t, job.Error.String, `git auth provider "github" is not configured`)
srv.GitAuthProviders = []string{"github"}
srv.GitAuthConfigs = []*gitauth.Config{{
ID: "github",
}}
completeJob()
job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID)
require.NoError(t, err)
@ -930,8 +952,7 @@ func TestCompleteJob(t *testing.T) {
Valid: true,
}
}
workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{
ID: uuid.New(),
workspace := dbgen.Workspace(t, srv.Database, database.Workspace{
TemplateID: template.ID,
Ttl: workspaceTTL,
})
@ -942,26 +963,19 @@ func TestCompleteJob(t *testing.T) {
},
JobID: uuid.New(),
})
require.NoError(t, err)
build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
ID: uuid.New(),
build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: version.ID,
Transition: c.transition,
Reason: database.BuildReasonInitiator,
})
require.NoError(t, err)
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(),
FileID: file.ID,
Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeWorkspaceBuild,
StorageMethod: database.ProvisionerStorageMethodFile,
job := dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
})
require.NoError(t, err)
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
WorkerID: uuid.NullUUID{
UUID: srv.ID,
@ -1022,7 +1036,6 @@ func TestCompleteJob(t *testing.T) {
})
}
})
t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel()
srv := setup(t, false)

View File

@ -291,7 +291,7 @@ func (api *API) templateVersionGitAuth(rw http.ResponseWriter, r *http.Request)
return
}
_, updated, err := refreshGitToken(ctx, api.Database, apiKey.UserID, config, authLink)
_, updated, err := config.RefreshToken(ctx, api.Database, authLink)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to refresh git auth token.",

View File

@ -462,7 +462,7 @@ func TestTemplateVersionsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,

View File

@ -6,10 +6,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
@ -27,46 +25,6 @@ import (
"github.com/coder/coder/testutil"
)
type oauth2Config struct {
token *oauth2.Token
}
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
return "/?state=" + url.QueryEscape(state)
}
func (o *oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
if o.token != nil {
return o.token, nil
}
return &oauth2.Token{
AccessToken: "token",
RefreshToken: "refresh",
Expiry: database.Now().Add(time.Hour),
}, nil
}
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
return &oauth2TokenSource{
token: o.token,
}
}
type oauth2TokenSource struct {
token *oauth2.Token
}
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
if o.token != nil {
return o.token, nil
}
return &oauth2.Token{
AccessToken: "token",
RefreshToken: "refresh",
Expiry: database.Now().Add(time.Hour),
}, nil
}
func TestUserAuthMethods(t *testing.T) {
t.Parallel()
t.Run("Password", func(t *testing.T) {
@ -110,7 +68,7 @@ func TestUserOAuth2Github(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
State: &stateActive,
@ -137,7 +95,7 @@ func TestUserOAuth2Github(t *testing.T) {
GithubOAuth2Config: &coderd.GithubOAuth2Config{
AllowOrganizations: []string{"coder"},
AllowTeams: []coderd.GithubOAuth2Team{{"another", "something"}, {"coder", "frontend"}},
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
State: &stateActive,
@ -171,7 +129,7 @@ func TestUserOAuth2Github(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
@ -210,7 +168,7 @@ func TestUserOAuth2Github(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
@ -247,7 +205,7 @@ func TestUserOAuth2Github(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
AllowOrganizations: []string{"coder"},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
@ -290,7 +248,7 @@ func TestUserOAuth2Github(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
GithubOAuth2Config: &coderd.GithubOAuth2Config{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
AllowOrganizations: []string{"coder"},
AllowSignups: true,
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
@ -343,7 +301,7 @@ func TestUserOAuth2Github(t *testing.T) {
AllowSignups: true,
AllowOrganizations: []string{"coder"},
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "frontend"}},
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
State: &stateActive,
@ -387,7 +345,7 @@ func TestUserOAuth2Github(t *testing.T) {
AllowSignups: true,
AllowOrganizations: []string{"coder", "nil"},
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "backend"}},
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{
{
@ -439,7 +397,7 @@ func TestUserOAuth2Github(t *testing.T) {
AllowSignups: true,
AllowOrganizations: []string{"coder", "nil"},
AllowTeams: []coderd.GithubOAuth2Team{{"nil", "null"}},
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{
{
@ -490,7 +448,7 @@ func TestUserOAuth2Github(t *testing.T) {
GithubOAuth2Config: &coderd.GithubOAuth2Config{
AllowSignups: true,
AllowEveryone: true,
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{}, nil
},
@ -529,7 +487,7 @@ func TestUserOAuth2Github(t *testing.T) {
AllowSignups: true,
AllowOrganizations: []string{"coder"},
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "frontend"}},
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
return []*github.Membership{{
State: &statePending,
@ -830,7 +788,7 @@ func TestUserOIDC(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
},
})
numLogs := len(auditor.AuditLogs)
@ -854,8 +812,8 @@ func TestUserOIDC(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
Auditor: auditor,
OIDCConfig: &coderd.OIDCConfig{
OAuth2Config: &oauth2Config{
token: (&oauth2.Token{
OAuth2Config: &testutil.OAuth2Config{
Token: (&oauth2.Token{
AccessToken: "token",
}).WithExtra(map[string]interface{}{
"id_token": "invalid",

View File

@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
@ -21,7 +20,6 @@ import (
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/oauth2"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/tailcfg"
@ -1327,7 +1325,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
continue
}
if gitAuthConfig.ValidateURL != "" {
valid, err := validateGitToken(ctx, gitAuthConfig.ValidateURL, gitAuthLink.OAuthAccessToken)
valid, err := gitAuthConfig.ValidateToken(ctx, gitAuthLink.OAuthAccessToken)
if err != nil {
api.Logger.Warn(ctx, "failed to validate git auth token",
slog.F("workspace_owner_id", workspace.OwnerID.String()),
@ -1373,7 +1371,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
return
}
gitAuthLink, updated, err := refreshGitToken(ctx, api.Database, workspace.OwnerID, gitAuthConfig, gitAuthLink)
gitAuthLink, updated, err := gitAuthConfig.RefreshToken(ctx, api.Database, gitAuthLink)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to refresh git auth token.",
@ -1390,74 +1388,6 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
httpapi.Write(ctx, rw, http.StatusOK, formatGitAuthAccessToken(gitAuthConfig.Type, gitAuthLink.OAuthAccessToken))
}
func refreshGitToken(ctx context.Context, db database.Store, owner uuid.UUID, gitAuthConfig *gitauth.Config, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
// If the token is expired and refresh is disabled, we prompt
// the user to authenticate again.
if gitAuthConfig.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
return gitAuthLink, false, nil
}
token, err := gitAuthConfig.TokenSource(ctx, &oauth2.Token{
AccessToken: gitAuthLink.OAuthAccessToken,
RefreshToken: gitAuthLink.OAuthRefreshToken,
Expiry: gitAuthLink.OAuthExpiry,
}).Token()
if err != nil {
return gitAuthLink, false, nil
}
if gitAuthConfig.ValidateURL != "" {
valid, err := validateGitToken(ctx, gitAuthConfig.ValidateURL, token.AccessToken)
if err != nil {
return gitAuthLink, false, xerrors.Errorf("validate git auth token: %w", err)
}
if !valid {
// The token is no longer valid!
return gitAuthLink, false, nil
}
}
if token.AccessToken != gitAuthLink.OAuthAccessToken {
// Update it
gitAuthLink, err = db.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
ProviderID: gitAuthConfig.ID,
UserID: owner,
UpdatedAt: database.Now(),
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
})
if err != nil {
return gitAuthLink, false, xerrors.Errorf("update git auth link: %w", err)
}
}
return gitAuthLink, true, nil
}
// validateGitToken ensures the git token provided is valid
// against the provided URL.
func validateGitToken(ctx context.Context, validateURL, token string) (bool, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, validateURL, nil)
if err != nil {
return false, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := http.DefaultClient.Do(req)
if err != nil {
return false, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusUnauthorized {
// The token is no longer valid!
return false, nil
}
if res.StatusCode != http.StatusOK {
data, _ := io.ReadAll(res.Body)
return false, xerrors.Errorf("status %d: body: %s", res.StatusCode, data)
}
return true, nil
}
// Provider types have different username/password formats.
func formatGitAuthAccessToken(typ codersdk.GitProvider, token string) agentsdk.GitAuthResponse {
var resp agentsdk.GitAuthResponse

View File

@ -786,7 +786,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,
@ -830,7 +830,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,
@ -844,7 +844,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,
@ -872,7 +872,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
ValidateURL: srv.URL,
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,
@ -923,8 +923,8 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{
token: &oauth2.Token{
OAuth2Config: &testutil.OAuth2Config{
Token: &oauth2.Token{
AccessToken: "token",
RefreshToken: "something",
Expiry: database.Now().Add(-time.Hour),
@ -973,7 +973,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
GitAuthConfigs: []*gitauth.Config{{
OAuth2Config: &oauth2Config{},
OAuth2Config: &testutil.OAuth2Config{},
ID: "github",
Regex: regexp.MustCompile(`github\.com`),
Type: codersdk.GitProviderGitHub,
@ -1011,7 +1011,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
resp := coderdtest.RequestGitAuthCallback(t, "github", client)
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
token = <-tokenChan
require.Equal(t, "token", token.Username)
require.Equal(t, "access_token", token.Username)
token, err = agentClient.GitAuth(context.Background(), "github.com/asd/asd", false)
require.NoError(t, err)