mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
feat: pass access_token
to coder_git_auth
resource (#6713)
This allows template authors to leverage git auth to perform custom actions, like clone repositories.
This commit is contained in:
@ -831,10 +831,6 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
|
||||
|
||||
mux := drpcmux.New()
|
||||
|
||||
gitAuthProviders := make([]string, 0, len(api.GitAuthConfigs))
|
||||
for _, cfg := range api.GitAuthConfigs {
|
||||
gitAuthProviders = append(gitAuthProviders, cfg.ID)
|
||||
}
|
||||
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
|
||||
AccessURL: api.AccessURL,
|
||||
ID: daemon.ID,
|
||||
@ -842,7 +838,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
|
||||
Database: api.Database,
|
||||
Pubsub: api.Pubsub,
|
||||
Provisioners: daemon.Provisioners,
|
||||
GitAuthProviders: gitAuthProviders,
|
||||
GitAuthConfigs: api.GitAuthConfigs,
|
||||
Telemetry: api.Telemetry,
|
||||
Tags: tags,
|
||||
QuotaCommitter: &api.QuotaCommitter,
|
||||
|
@ -1,13 +1,17 @@
|
||||
package gitauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
@ -34,6 +38,77 @@ type Config struct {
|
||||
ValidateURL string
|
||||
}
|
||||
|
||||
// RefreshToken automatically refreshes the token if expired and permitted.
|
||||
// It returns the token and a bool indicating if the token was refreshed.
|
||||
func (c *Config) RefreshToken(ctx context.Context, db database.Store, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
|
||||
// If the token is expired and refresh is disabled, we prompt
|
||||
// the user to authenticate again.
|
||||
if c.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
|
||||
token, err := c.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: gitAuthLink.OAuthAccessToken,
|
||||
RefreshToken: gitAuthLink.OAuthRefreshToken,
|
||||
Expiry: gitAuthLink.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
// Even if the token fails to be obtained, we still return false because
|
||||
// we aren't trying to surface an error, we're just trying to obtain a valid token.
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
|
||||
if c.ValidateURL != "" {
|
||||
valid, err := c.ValidateToken(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return gitAuthLink, false, xerrors.Errorf("validate git auth token: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
// The token is no longer valid!
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if token.AccessToken != gitAuthLink.OAuthAccessToken {
|
||||
// Update it
|
||||
gitAuthLink, err = db.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
|
||||
ProviderID: c.ID,
|
||||
UserID: gitAuthLink.UserID,
|
||||
UpdatedAt: database.Now(),
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
})
|
||||
if err != nil {
|
||||
return gitAuthLink, false, xerrors.Errorf("update git auth link: %w", err)
|
||||
}
|
||||
}
|
||||
return gitAuthLink, true, nil
|
||||
}
|
||||
|
||||
// ValidateToken ensures the Git token provided is valid!
|
||||
func (c *Config) ValidateToken(ctx context.Context, token string) (bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.ValidateURL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
// The token is no longer valid!
|
||||
return false, nil
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
data, _ := io.ReadAll(res.Body)
|
||||
return false, xerrors.Errorf("status %d: body: %s", res.StatusCode, data)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ConvertConfig converts the SDK configuration entry format
|
||||
// to the parsed and ready-to-consume in coderd provider type.
|
||||
func ConvertConfig(entries []codersdk.GitAuthConfig, accessURL *url.URL) ([]*Config, error) {
|
||||
|
@ -1,15 +1,122 @@
|
||||
package gitauth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestRefreshToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("FalseIfNoRefresh", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
NoRefresh: true,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthExpiry: time.Time{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
t.Run("FalseIfTokenSourceFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
TokenSourceFunc: func() (*oauth2.Token, error) {
|
||||
return nil, xerrors.New("failure")
|
||||
},
|
||||
},
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
t.Run("ValidateServerError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Failure"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, _, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.ErrorContains(t, err, "Failure")
|
||||
})
|
||||
t.Run("ValidateFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("Not permitted"))
|
||||
}))
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, refreshed, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{})
|
||||
require.NoError(t, err)
|
||||
require.False(t, refreshed)
|
||||
})
|
||||
t.Run("ValidateNoUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
validated := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
close(validated)
|
||||
}))
|
||||
accessToken := "testing"
|
||||
config := &gitauth.Config{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
},
|
||||
},
|
||||
ValidateURL: srv.URL,
|
||||
}
|
||||
_, valid, err := config.RefreshToken(context.Background(), nil, database.GitAuthLink{
|
||||
OAuthAccessToken: accessToken,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
<-validated
|
||||
})
|
||||
t.Run("Updates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := &gitauth.Config{
|
||||
ID: "test",
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
db := dbfake.New()
|
||||
link := dbgen.GitAuthLink(t, db, database.GitAuthLink{
|
||||
ProviderID: config.ID,
|
||||
OAuthAccessToken: "initial",
|
||||
})
|
||||
_, valid, err := config.RefreshToken(context.Background(), db, link)
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvertYAML(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range []struct {
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func randomAPIKeyParts() (id string, secret string) {
|
||||
@ -462,10 +463,8 @@ func TestAPIKey(t *testing.T) {
|
||||
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
OAuth2Configs: &httpmw.OAuth2Configs{
|
||||
Github: &oauth2Config{
|
||||
tokenSource: oauth2TokenSource(func() (*oauth2.Token, error) {
|
||||
return oauthToken, nil
|
||||
}),
|
||||
Github: &testutil.OAuth2Config{
|
||||
Token: oauthToken,
|
||||
},
|
||||
},
|
||||
RedirectToLogin: false,
|
||||
@ -597,25 +596,3 @@ func TestAPIKey(t *testing.T) {
|
||||
require.Equal(t, sentAPIKey.LoginType, gotAPIKey.LoginType)
|
||||
})
|
||||
}
|
||||
|
||||
type oauth2Config struct {
|
||||
tokenSource oauth2TokenSource
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return o.tokenSource
|
||||
}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{}, nil
|
||||
}
|
||||
|
||||
type oauth2TokenSource func() (*oauth2.Token, error)
|
||||
|
||||
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o()
|
||||
}
|
||||
|
@ -28,11 +28,11 @@ import (
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/coderd/schedule"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
@ -50,7 +50,7 @@ type Server struct {
|
||||
ID uuid.UUID
|
||||
Logger slog.Logger
|
||||
Provisioners []database.ProvisionerType
|
||||
GitAuthProviders []string
|
||||
GitAuthConfigs []*gitauth.Config
|
||||
Tags json.RawMessage
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
@ -210,6 +210,48 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
|
||||
}
|
||||
|
||||
gitAuthProviders := []*sdkproto.GitAuthProvider{}
|
||||
for _, p := range templateVersion.GitAuthProviders {
|
||||
link, err := server.Database.GetGitAuthLink(ctx, database.GetGitAuthLinkParams{
|
||||
ProviderID: p,
|
||||
UserID: owner.ID,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("acquire git auth link: %s", err))
|
||||
}
|
||||
var config *gitauth.Config
|
||||
for _, c := range server.GitAuthConfigs {
|
||||
if c.ID != p {
|
||||
continue
|
||||
}
|
||||
config = c
|
||||
break
|
||||
}
|
||||
// We weren't able to find a matching config for the ID!
|
||||
if config == nil {
|
||||
server.Logger.Warn(ctx, "workspace build job is missing git provider",
|
||||
slog.F("git_provider_id", p),
|
||||
slog.F("template_version_id", templateVersion.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID))
|
||||
continue
|
||||
}
|
||||
|
||||
link, valid, err := config.RefreshToken(ctx, server.Database, link)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("refresh git auth link %q: %s", p, err))
|
||||
}
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
gitAuthProviders = append(gitAuthProviders, &sdkproto.GitAuthProvider{
|
||||
Id: p,
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
})
|
||||
}
|
||||
|
||||
protoJob.Type = &proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
WorkspaceBuildId: workspaceBuild.ID.String(),
|
||||
@ -218,6 +260,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
ParameterValues: protoParameters,
|
||||
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
|
||||
VariableValues: asVariableValues(templateVariables),
|
||||
GitAuthProviders: gitAuthProviders,
|
||||
Metadata: &sdkproto.Provision_Metadata{
|
||||
CoderUrl: server.AccessURL.String(),
|
||||
WorkspaceTransition: transition,
|
||||
@ -857,7 +900,14 @@ func (server *Server) CompleteJob(ctx context.Context, completed *proto.Complete
|
||||
var completedError sql.NullString
|
||||
|
||||
for _, gitAuthProvider := range jobType.TemplateImport.GitAuthProviders {
|
||||
if !slice.Contains(server.GitAuthProviders, gitAuthProvider) {
|
||||
contains := false
|
||||
for _, configuredProvider := range server.GitAuthConfigs {
|
||||
if configuredProvider.ID == gitAuthProvider {
|
||||
contains = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !contains {
|
||||
completedError = sql.NullString{
|
||||
String: fmt.Sprintf("git auth provider %q is not configured", gitAuthProvider),
|
||||
Valid: true,
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
@ -45,11 +46,9 @@ func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
OAuthExpiry: database.Now().Add(-time.Hour),
|
||||
})
|
||||
_, err := obtainOIDCAccessToken(ctx, db, &oauth2Config{
|
||||
tokenSource: func() (*oauth2.Token, error) {
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
_, err := obtainOIDCAccessToken(ctx, db, &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
},
|
||||
}, user.ID)
|
||||
require.NoError(t, err)
|
||||
@ -61,25 +60,3 @@ func TestObtainOIDCAccessToken(t *testing.T) {
|
||||
require.Equal(t, "token", link.OAuthAccessToken)
|
||||
})
|
||||
}
|
||||
|
||||
type oauth2Config struct {
|
||||
tokenSource oauth2TokenSource
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return o.tokenSource
|
||||
}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return &oauth2.Token{}, nil
|
||||
}
|
||||
|
||||
type oauth2TokenSource func() (*oauth2.Token, error)
|
||||
|
||||
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
return o()
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/coderd/schedule"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
@ -98,6 +99,11 @@ func TestAcquireJob(t *testing.T) {
|
||||
t.Run("WorkspaceBuildJob", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := setup(t, false)
|
||||
gitAuthProvider := "github"
|
||||
srv.GitAuthConfigs = []*gitauth.Config{{
|
||||
ID: gitAuthProvider,
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
}}
|
||||
ctx := context.Background()
|
||||
|
||||
user := dbgen.User(t, srv.Database, database.User{})
|
||||
@ -107,6 +113,10 @@ func TestAcquireJob(t *testing.T) {
|
||||
OAuthExpiry: database.Now().Add(time.Hour),
|
||||
OAuthAccessToken: "access-token",
|
||||
})
|
||||
dbgen.GitAuthLink(t, srv.Database, database.GitAuthLink{
|
||||
ProviderID: gitAuthProvider,
|
||||
UserID: user.ID,
|
||||
})
|
||||
template := dbgen.Template(t, srv.Database, database.Template{
|
||||
Name: "template",
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
@ -120,6 +130,12 @@ func TestAcquireJob(t *testing.T) {
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
err := srv.Database.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
|
||||
JobID: version.JobID,
|
||||
GitAuthProviders: []string{gitAuthProvider},
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Import version job
|
||||
_ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
|
||||
ID: version.JobID,
|
||||
@ -214,6 +230,10 @@ func TestAcquireJob(t *testing.T) {
|
||||
Value: "second_value",
|
||||
},
|
||||
},
|
||||
GitAuthProviders: []*sdkproto.GitAuthProvider{{
|
||||
Id: gitAuthProvider,
|
||||
AccessToken: "access_token",
|
||||
}},
|
||||
Metadata: &sdkproto.Provision_Metadata{
|
||||
CoderUrl: srv.AccessURL.String(),
|
||||
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
|
||||
@ -795,7 +815,9 @@ func TestCompleteJob(t *testing.T) {
|
||||
job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, job.Error.String, `git auth provider "github" is not configured`)
|
||||
srv.GitAuthProviders = []string{"github"}
|
||||
srv.GitAuthConfigs = []*gitauth.Config{{
|
||||
ID: "github",
|
||||
}}
|
||||
completeJob()
|
||||
job, err = srv.Database.GetProvisionerJobByID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
@ -930,8 +952,7 @@ func TestCompleteJob(t *testing.T) {
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
workspace, err := srv.Database.InsertWorkspace(ctx, database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
workspace := dbgen.Workspace(t, srv.Database, database.Workspace{
|
||||
TemplateID: template.ID,
|
||||
Ttl: workspaceTTL,
|
||||
})
|
||||
@ -942,26 +963,19 @@ func TestCompleteJob(t *testing.T) {
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
build, err := srv.Database.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
|
||||
ID: uuid.New(),
|
||||
build := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: c.transition,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
job, err := srv.Database.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
FileID: file.ID,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
job := dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
|
||||
FileID: file.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: build.ID,
|
||||
})),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = srv.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: srv.ID,
|
||||
@ -1022,7 +1036,6 @@ func TestCompleteJob(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TemplateDryRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := setup(t, false)
|
||||
|
@ -291,7 +291,7 @@ func (api *API) templateVersionGitAuth(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
_, updated, err := refreshGitToken(ctx, api.Database, apiKey.UserID, config, authLink)
|
||||
_, updated, err := config.RefreshToken(ctx, api.Database, authLink)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to refresh git auth token.",
|
||||
|
@ -462,7 +462,7 @@ func TestTemplateVersionsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
|
@ -6,10 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt"
|
||||
@ -27,46 +25,6 @@ import (
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
type oauth2Config struct {
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
func (*oauth2Config) AuthCodeURL(state string, _ ...oauth2.AuthCodeOption) string {
|
||||
return "/?state=" + url.QueryEscape(state)
|
||||
}
|
||||
|
||||
func (o *oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
if o.token != nil {
|
||||
return o.token, nil
|
||||
}
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: "refresh",
|
||||
Expiry: database.Now().Add(time.Hour),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
|
||||
return &oauth2TokenSource{
|
||||
token: o.token,
|
||||
}
|
||||
}
|
||||
|
||||
type oauth2TokenSource struct {
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
|
||||
if o.token != nil {
|
||||
return o.token, nil
|
||||
}
|
||||
return &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: "refresh",
|
||||
Expiry: database.Now().Add(time.Hour),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestUserAuthMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Password", func(t *testing.T) {
|
||||
@ -110,7 +68,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
State: &stateActive,
|
||||
@ -137,7 +95,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowTeams: []coderd.GithubOAuth2Team{{"another", "something"}, {"coder", "frontend"}},
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
State: &stateActive,
|
||||
@ -171,7 +129,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
@ -210,7 +168,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
@ -247,7 +205,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
@ -290,7 +248,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowSignups: true,
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
@ -343,7 +301,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
AllowSignups: true,
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "frontend"}},
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
State: &stateActive,
|
||||
@ -387,7 +345,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
AllowSignups: true,
|
||||
AllowOrganizations: []string{"coder", "nil"},
|
||||
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "backend"}},
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{
|
||||
{
|
||||
@ -439,7 +397,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
AllowSignups: true,
|
||||
AllowOrganizations: []string{"coder", "nil"},
|
||||
AllowTeams: []coderd.GithubOAuth2Team{{"nil", "null"}},
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{
|
||||
{
|
||||
@ -490,7 +448,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
GithubOAuth2Config: &coderd.GithubOAuth2Config{
|
||||
AllowSignups: true,
|
||||
AllowEveryone: true,
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{}, nil
|
||||
},
|
||||
@ -529,7 +487,7 @@ func TestUserOAuth2Github(t *testing.T) {
|
||||
AllowSignups: true,
|
||||
AllowOrganizations: []string{"coder"},
|
||||
AllowTeams: []coderd.GithubOAuth2Team{{"coder", "frontend"}},
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
||||
return []*github.Membership{{
|
||||
State: &statePending,
|
||||
@ -830,7 +788,7 @@ func TestUserOIDC(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
},
|
||||
})
|
||||
numLogs := len(auditor.AuditLogs)
|
||||
@ -854,8 +812,8 @@ func TestUserOIDC(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Auditor: auditor,
|
||||
OIDCConfig: &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2Config{
|
||||
token: (&oauth2.Token{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: (&oauth2.Token{
|
||||
AccessToken: "token",
|
||||
}).WithExtra(map[string]interface{}{
|
||||
"id_token": "invalid",
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@ -21,7 +20,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -1327,7 +1325,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
|
||||
continue
|
||||
}
|
||||
if gitAuthConfig.ValidateURL != "" {
|
||||
valid, err := validateGitToken(ctx, gitAuthConfig.ValidateURL, gitAuthLink.OAuthAccessToken)
|
||||
valid, err := gitAuthConfig.ValidateToken(ctx, gitAuthLink.OAuthAccessToken)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to validate git auth token",
|
||||
slog.F("workspace_owner_id", workspace.OwnerID.String()),
|
||||
@ -1373,7 +1371,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
gitAuthLink, updated, err := refreshGitToken(ctx, api.Database, workspace.OwnerID, gitAuthConfig, gitAuthLink)
|
||||
gitAuthLink, updated, err := gitAuthConfig.RefreshToken(ctx, api.Database, gitAuthLink)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to refresh git auth token.",
|
||||
@ -1390,74 +1388,6 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, formatGitAuthAccessToken(gitAuthConfig.Type, gitAuthLink.OAuthAccessToken))
|
||||
}
|
||||
|
||||
func refreshGitToken(ctx context.Context, db database.Store, owner uuid.UUID, gitAuthConfig *gitauth.Config, gitAuthLink database.GitAuthLink) (database.GitAuthLink, bool, error) {
|
||||
// If the token is expired and refresh is disabled, we prompt
|
||||
// the user to authenticate again.
|
||||
if gitAuthConfig.NoRefresh && gitAuthLink.OAuthExpiry.Before(database.Now()) {
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
|
||||
token, err := gitAuthConfig.TokenSource(ctx, &oauth2.Token{
|
||||
AccessToken: gitAuthLink.OAuthAccessToken,
|
||||
RefreshToken: gitAuthLink.OAuthRefreshToken,
|
||||
Expiry: gitAuthLink.OAuthExpiry,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
|
||||
if gitAuthConfig.ValidateURL != "" {
|
||||
valid, err := validateGitToken(ctx, gitAuthConfig.ValidateURL, token.AccessToken)
|
||||
if err != nil {
|
||||
return gitAuthLink, false, xerrors.Errorf("validate git auth token: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
// The token is no longer valid!
|
||||
return gitAuthLink, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if token.AccessToken != gitAuthLink.OAuthAccessToken {
|
||||
// Update it
|
||||
gitAuthLink, err = db.UpdateGitAuthLink(ctx, database.UpdateGitAuthLinkParams{
|
||||
ProviderID: gitAuthConfig.ID,
|
||||
UserID: owner,
|
||||
UpdatedAt: database.Now(),
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
})
|
||||
if err != nil {
|
||||
return gitAuthLink, false, xerrors.Errorf("update git auth link: %w", err)
|
||||
}
|
||||
}
|
||||
return gitAuthLink, true, nil
|
||||
}
|
||||
|
||||
// validateGitToken ensures the git token provided is valid
|
||||
// against the provided URL.
|
||||
func validateGitToken(ctx context.Context, validateURL, token string) (bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, validateURL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
// The token is no longer valid!
|
||||
return false, nil
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
data, _ := io.ReadAll(res.Body)
|
||||
return false, xerrors.Errorf("status %d: body: %s", res.StatusCode, data)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Provider types have different username/password formats.
|
||||
func formatGitAuthAccessToken(typ codersdk.GitProvider, token string) agentsdk.GitAuthResponse {
|
||||
var resp agentsdk.GitAuthResponse
|
||||
|
@ -786,7 +786,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
@ -830,7 +830,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
@ -844,7 +844,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
@ -872,7 +872,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
ValidateURL: srv.URL,
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
@ -923,8 +923,8 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{
|
||||
token: &oauth2.Token{
|
||||
OAuth2Config: &testutil.OAuth2Config{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "token",
|
||||
RefreshToken: "something",
|
||||
Expiry: database.Now().Add(-time.Hour),
|
||||
@ -973,7 +973,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
GitAuthConfigs: []*gitauth.Config{{
|
||||
OAuth2Config: &oauth2Config{},
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
ID: "github",
|
||||
Regex: regexp.MustCompile(`github\.com`),
|
||||
Type: codersdk.GitProviderGitHub,
|
||||
@ -1011,7 +1011,7 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) {
|
||||
resp := coderdtest.RequestGitAuthCallback(t, "github", client)
|
||||
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
token = <-tokenChan
|
||||
require.Equal(t, "token", token.Username)
|
||||
require.Equal(t, "access_token", token.Username)
|
||||
|
||||
token, err = agentClient.GitAuth(context.Background(), "github.com/asd/asd", false)
|
||||
require.NoError(t, err)
|
||||
|
Reference in New Issue
Block a user