feat: add support for optional external auth providers (#12021)

This commit is contained in:
Kayla Washburn-Love
2024-02-21 11:18:38 -07:00
committed by GitHub
parent 78c9f82719
commit 475c3650ca
39 changed files with 1495 additions and 727 deletions

View File

@ -501,10 +501,16 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
return nil, failJob(fmt.Sprintf("get workspace build parameters: %s", err))
}
externalAuthProviders := []*sdkproto.ExternalAuthProvider{}
for _, p := range templateVersion.ExternalAuthProviders {
dbExternalAuthProviders := []database.ExternalAuthProvider{}
err = json.Unmarshal(templateVersion.ExternalAuthProviders, &dbExternalAuthProviders)
if err != nil {
return nil, xerrors.Errorf("failed to deserialize external_auth_providers value: %w", err)
}
externalAuthProviders := make([]*sdkproto.ExternalAuthProvider, 0, len(dbExternalAuthProviders))
for _, p := range dbExternalAuthProviders {
link, err := s.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: p,
ProviderID: p.ID,
UserID: owner.ID,
})
if errors.Is(err, sql.ErrNoRows) {
@ -515,7 +521,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
}
var config *externalauth.Config
for _, c := range s.ExternalAuthConfigs {
if c.ID != p {
if c.ID != p.ID {
continue
}
config = c
@ -524,7 +530,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
// We weren't able to find a matching config for the ID!
if config == nil {
s.Logger.Warn(ctx, "workspace build job is missing external auth provider",
slog.F("provider_id", p),
slog.F("provider_id", p.ID),
slog.F("template_version_id", templateVersion.ID),
slog.F("workspace_id", workspaceBuild.WorkspaceID))
continue
@ -532,13 +538,13 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
link, valid, err := config.RefreshToken(ctx, s.Database, link)
if err != nil {
return nil, failJob(fmt.Sprintf("refresh external auth link %q: %s", p, err))
return nil, failJob(fmt.Sprintf("refresh external auth link %q: %s", p.ID, err))
}
if !valid {
continue
}
externalAuthProviders = append(externalAuthProviders, &sdkproto.ExternalAuthProvider{
Id: p,
Id: p.ID,
AccessToken: link.OAuthAccessToken,
})
}
@ -1133,23 +1139,49 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
for _, externalAuthProvider := range jobType.TemplateImport.ExternalAuthProviders {
contains := false
for _, configuredProvider := range s.ExternalAuthConfigs {
if configuredProvider.ID == externalAuthProvider {
if configuredProvider.ID == externalAuthProvider.Id {
contains = true
break
}
}
if !contains {
completedError = sql.NullString{
String: fmt.Sprintf("external auth provider %q is not configured", externalAuthProvider),
String: fmt.Sprintf("external auth provider %q is not configured", externalAuthProvider.Id),
Valid: true,
}
break
}
}
// Fallback to `ExternalAuthProvidersNames` if it was specified and `ExternalAuthProviders`
// was not. Gives us backwards compatibility with custom provisioners that haven't been
// updated to use the new field yet.
var externalAuthProviders []database.ExternalAuthProvider
if providersLen := len(jobType.TemplateImport.ExternalAuthProviders); providersLen > 0 {
externalAuthProviders = make([]database.ExternalAuthProvider, 0, providersLen)
for _, provider := range jobType.TemplateImport.ExternalAuthProviders {
externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{
ID: provider.Id,
Optional: provider.Optional,
})
}
} else if namesLen := len(jobType.TemplateImport.ExternalAuthProvidersNames); namesLen > 0 {
externalAuthProviders = make([]database.ExternalAuthProvider, 0, namesLen)
for _, providerID := range jobType.TemplateImport.ExternalAuthProvidersNames {
externalAuthProviders = append(externalAuthProviders, database.ExternalAuthProvider{
ID: providerID,
})
}
}
externalAuthProvidersMessage, err := json.Marshal(externalAuthProviders)
if err != nil {
return nil, xerrors.Errorf("failed to serialize external_auth_providers value: %w", err)
}
err = s.Database.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
JobID: jobID,
ExternalAuthProviders: jobType.TemplateImport.ExternalAuthProviders,
ExternalAuthProviders: json.RawMessage(externalAuthProvidersMessage),
UpdatedAt: dbtime.Now(),
})
if err != nil {

View File

@ -172,11 +172,14 @@ func TestAcquireJob(t *testing.T) {
// create an API key with an expiration within the bounds of the
// deployment config.
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
gitAuthProvider := "github"
gitAuthProvider := &sdkproto.ExternalAuthProviderResource{
Id: "github",
}
srv, db, ps, _ := setup(t, false, &overrides{
deploymentValues: dv,
externalAuthConfigs: []*externalauth.Config{{
ID: gitAuthProvider,
ID: gitAuthProvider.Id,
InstrumentedOAuth2Config: &testutil.OAuth2Config{},
}},
})
@ -191,7 +194,7 @@ func TestAcquireJob(t *testing.T) {
OAuthAccessToken: "access-token",
})
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
ProviderID: gitAuthProvider,
ProviderID: gitAuthProvider.Id,
UserID: user.ID,
})
template := dbgen.Template(t, db, database.Template{
@ -207,9 +210,14 @@ func TestAcquireJob(t *testing.T) {
},
JobID: uuid.New(),
})
err := db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
externalAuthProviders, err := json.Marshal([]database.ExternalAuthProvider{{
ID: gitAuthProvider.Id,
Optional: gitAuthProvider.Optional,
}})
require.NoError(t, err)
err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
JobID: version.JobID,
ExternalAuthProviders: []string{gitAuthProvider},
ExternalAuthProviders: json.RawMessage(externalAuthProviders),
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
@ -321,7 +329,7 @@ func TestAcquireJob(t *testing.T) {
},
},
ExternalAuthProviders: []*sdkproto.ExternalAuthProvider{{
Id: gitAuthProvider,
Id: gitAuthProvider.Id,
AccessToken: "access_token",
}},
Metadata: &sdkproto.Metadata{
@ -949,8 +957,10 @@ func TestCompleteJob(t *testing.T) {
Name: "hello",
Type: "aws_instance",
}},
StopResources: []*sdkproto.Resource{},
ExternalAuthProviders: []string{"github"},
StopResources: []*sdkproto.Resource{},
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{
Id: "github",
}},
},
},
})
@ -1002,7 +1012,7 @@ func TestCompleteJob(t *testing.T) {
Type: "aws_instance",
}},
StopResources: []*sdkproto.Resource{},
ExternalAuthProviders: []string{"github"},
ExternalAuthProviders: []*sdkproto.ExternalAuthProviderResource{{Id: "github"}},
},
},
})
@ -1776,7 +1786,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
Tags: database.StringMap{},
LastSeenAt: sql.NullTime{},
Version: buildinfo.Version(),
APIVersion: proto.VersionCurrent.String(),
APIVersion: proto.CurrentVersion.String(),
})
require.NoError(t, err)