mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
feat: add support for optional external auth providers (#12021)
This commit is contained in:
committed by
GitHub
parent
78c9f82719
commit
475c3650ca
@ -501,10 +501,16 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
|
||||
}
|
||||
|
||||
externalAuthProviders := []*sdkproto.ExternalAuthProvider{}
|
||||
for _, p := range templateVersion.ExternalAuthProviders {
|
||||
dbExternalAuthProviders := []database.ExternalAuthProvider{}
|
||||
err = json.Unmarshal(templateVersion.ExternalAuthProviders, &dbExternalAuthProviders)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to deserialize external_auth_providers value: %w", err)
|
||||
}
|
||||
|
||||
externalAuthProviders := make([]*sdkproto.ExternalAuthProvider, 0, len(dbExternalAuthProviders))
|
||||
for _, p := range dbExternalAuthProviders {
|
||||
link, err := s.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
|
||||
ProviderID: p,
|
||||
ProviderID: p.ID,
|
||||
UserID: owner.ID,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@ -515,7 +521,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
}
|
||||
var config *externalauth.Config
|
||||
for _, c := range s.ExternalAuthConfigs {
|
||||
if c.ID != p {
|
||||
if c.ID != p.ID {
|
||||
continue
|
||||
}
|
||||
config = c
|
||||
@ -524,7 +530,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
// We weren't able to find a matching config for the ID!
|
||||
if config == nil {
|
||||
s.Logger.Warn(ctx, "workspace build job is missing external auth provider",
|
||||
slog.F("provider_id", p),
|
||||
slog.F("provider_id", p.ID),
|
||||
slog.F("template_version_id", templateVersion.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID))
|
||||
continue
|
||||
@ -532,13 +538,13 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
|
||||
|
||||
link, valid, err := config.RefreshToken(ctx, s.Database, link)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("refresh external auth link %q: %s", p, err))
|
||||
return nil, failJob(fmt.Sprintf("refresh external auth link %q: %s", p.ID, err))
|
||||
}
|
||||
if !valid {
|
||||
continue
|
||||
}
|
||||
externalAuthProviders = append(externalAuthProviders, &sdkproto.ExternalAuthProvider{
|
||||
Id: p,
|
||||
Id: p.ID,
|
||||
AccessToken: link.OAuthAccessToken,
|
||||
})
|
||||
}
|
||||
@ -1133,23 +1139,49 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
for _, externalAuthProvider := range jobType.TemplateImport.ExternalAuthProviders {
|
||||
contains := false
|
||||
for _, configuredProvider := range s.ExternalAuthConfigs {
|
||||
if configuredProvider.ID == externalAuthProvider {
|
||||
if configuredProvider.ID == externalAuthProvider.Id {
|
||||
contains = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !contains {
|
||||
completedError = sql.NullString{
|
||||
String: fmt.Sprintf("external auth provider %q is not configured", externalAuthProvider),
|
||||
String: fmt.Sprintf("external auth provider %q is not configured", externalAuthProvider.Id),
|
||||
Valid: true,
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to `ExternalAuthProvidersNames` if it was specified and `ExternalAuthProviders`
|
||||
// was not. Gives us backwards compatibility with custom provisioners that haven't been
|
||||
// updated to use the new field yet.
|
||||
var externalAuthProviders []database.ExternalAuthProvider
|
||||
if providersLen := len(jobType.TemplateImport.ExternalAuthProviders); providersLen > 0 {
|
||||
externalAuthProviders = make([]database.ExternalAuthProvider, 0, providersLen)
|
||||
for _, provider := range jobType.TemplateImport.ExternalAuthProviders {
|
||||
externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{
|
||||
ID: provider.Id,
|
||||
Optional: provider.Optional,
|
||||
})
|
||||
}
|
||||
} else if namesLen := len(jobType.TemplateImport.ExternalAuthProvidersNames); namesLen > 0 {
|
||||
externalAuthProviders = make([]database.ExternalAuthProvider, 0, namesLen)
|
||||
for _, providerID := range jobType.TemplateImport.ExternalAuthProvidersNames {
|
||||
externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{
|
||||
ID: providerID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
externalAuthProvidersMessage, err := json.Marshal(externalAuthProviders)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to serialize external_auth_providers value: %w", err)
|
||||
}
|
||||
|
||||
err = s.Database.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
||||
JobID: jobID,
|
||||
ExternalAuthProviders: jobType.TemplateImport.ExternalAuthProviders,
|
||||
ExternalAuthProviders: json.RawMessage(externalAuthProvidersMessage),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -172,11 +172,14 @@ func TestAcquireJob(t *testing.T) {
|
||||
// create an API key with an expiration within the bounds of the
|
||||
// deployment config.
|
||||
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
|
||||
gitAuthProvider := "github"
|
||||
gitAuthProvider := &sdkproto.ExternalAuthProviderResource{
|
||||
Id: "github",
|
||||
}
|
||||
|
||||
srv, db, ps, _ := setup(t, false, &overrides{
|
||||
deploymentValues: dv,
|
||||
externalAuthConfigs: []*externalauth.Config{{
|
||||
ID: gitAuthProvider,
|
||||
ID: gitAuthProvider.Id,
|
||||
InstrumentedOAuth2Config: &testutil.OAuth2Config{},
|
||||
}},
|
||||
})
|
||||
@ -191,7 +194,7 @@ func TestAcquireJob(t *testing.T) {
|
||||
OAuthAccessToken: "access-token",
|
||||
})
|
||||
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
|
||||
ProviderID: gitAuthProvider,
|
||||
ProviderID: gitAuthProvider.Id,
|
||||
UserID: user.ID,
|
||||
})
|
||||
template := dbgen.Template(t, db, database.Template{
|
||||
@ -207,9 +210,14 @@ func TestAcquireJob(t *testing.T) {
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
err := db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
||||
externalAuthProviders, err := json.Marshal([]database.ExternalAuthProvider{{
|
||||
ID: gitAuthProvider.Id,
|
||||
Optional: gitAuthProvider.Optional,
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
||||
JobID: version.JobID,
|
||||
ExternalAuthProviders: []string{gitAuthProvider},
|
||||
ExternalAuthProviders: json.RawMessage(externalAuthProviders),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -321,7 +329,7 @@ func TestAcquireJob(t *testing.T) {
|
||||
},
|
||||
},
|
||||
ExternalAuthProviders: []*sdkproto.ExternalAuthProvider{{
|
||||
Id: gitAuthProvider,
|
||||
Id: gitAuthProvider.Id,
|
||||
AccessToken: "access_token",
|
||||
}},
|
||||
Metadata: &sdkproto.Metadata{
|
||||
@ -949,8 +957,10 @@ func TestCompleteJob(t *testing.T) {
|
||||
Name: "hello",
|
||||
Type: "aws_instance",
|
||||
}},
|
||||
StopResources: []*sdkproto.Resource{},
|
||||
ExternalAuthProviders: []string{"github"},
|
||||
StopResources: []*sdkproto.Resource{},
|
||||
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{
|
||||
Id: "github",
|
||||
}},
|
||||
},
|
||||
},
|
||||
})
|
||||
@ -1002,7 +1012,7 @@ func TestCompleteJob(t *testing.T) {
|
||||
Type: "aws_instance",
|
||||
}},
|
||||
StopResources: []*sdkproto.Resource{},
|
||||
ExternalAuthProviders: []string{"github"},
|
||||
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{Id: "github"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
@ -1776,7 +1786,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
|
||||
Tags: database.StringMap{},
|
||||
LastSeenAt: sql.NullTime{},
|
||||
Version: buildinfo.Version(),
|
||||
APIVersion: proto.VersionCurrent.String(),
|
||||
APIVersion: proto.CurrentVersion.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
Reference in New Issue
Block a user