mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
feat: improve transaction safety in CompleteJob function (#17970)
This PR refactors the CompleteJob function to use database transactions more consistently for better atomicity guarantees. The large function was broken down into three specialized handlers: - completeTemplateImportJob - completeWorkspaceBuildJob - completeTemplateDryRunJob Each handler now uses the Database.InTx wrapper to ensure all database operations for a job completion are performed within a single transaction, preventing partial updates in case of failures. Added comprehensive tests for transaction behavior for each job type. Fixes #17694 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@ -1340,14 +1340,56 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
|
||||
switch jobType := completed.Type.(type) {
|
||||
case *proto.CompletedJob_TemplateImport_:
|
||||
var input TemplateVersionImportJob
|
||||
err = json.Unmarshal(job.Input, &input)
|
||||
err = s.completeTemplateImportJob(ctx, job, jobID, jobType, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("template version ID is expected: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
case *proto.CompletedJob_WorkspaceBuild_:
|
||||
err = s.completeWorkspaceBuildJob(ctx, job, jobID, jobType, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *proto.CompletedJob_TemplateDryRun_:
|
||||
err = s.completeTemplateDryRunJob(ctx, job, jobID, jobType, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
if completed.Type == nil {
|
||||
return nil, xerrors.Errorf("type payload must be provided")
|
||||
}
|
||||
return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match",
|
||||
reflect.TypeOf(completed.Type).String())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
}
|
||||
|
||||
s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// completeTemplateImportJob handles completion of a template import job.
|
||||
// All database operations are performed within a transaction.
|
||||
func (s *server) completeTemplateImportJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateImport_, telemetrySnapshot *telemetry.Snapshot) error {
|
||||
var input TemplateVersionImportJob
|
||||
err := json.Unmarshal(job.Input, &input)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("template version ID is expected: %w", err)
|
||||
}
|
||||
|
||||
// Execute all database operations in a transaction
|
||||
return s.Database.InTx(func(db database.Store) error {
|
||||
now := s.timeNow()
|
||||
|
||||
// Process resources
|
||||
for transition, resources := range map[database.WorkspaceTransition][]*sdkproto.Resource{
|
||||
database.WorkspaceTransitionStart: jobType.TemplateImport.StartResources,
|
||||
database.WorkspaceTransitionStop: jobType.TemplateImport.StopResources,
|
||||
@ -1359,11 +1401,13 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
slog.F("resource_type", resource.Type),
|
||||
slog.F("transition", transition))
|
||||
|
||||
if err := InsertWorkspaceResource(ctx, s.Database, jobID, transition, resource, telemetrySnapshot); err != nil {
|
||||
return nil, xerrors.Errorf("insert resource: %w", err)
|
||||
if err := InsertWorkspaceResource(ctx, db, jobID, transition, resource, telemetrySnapshot); err != nil {
|
||||
return xerrors.Errorf("insert resource: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process modules
|
||||
for transition, modules := range map[database.WorkspaceTransition][]*sdkproto.Module{
|
||||
database.WorkspaceTransitionStart: jobType.TemplateImport.StartModules,
|
||||
database.WorkspaceTransitionStop: jobType.TemplateImport.StopModules,
|
||||
@ -1376,12 +1420,13 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
slog.F("module_key", module.Key),
|
||||
slog.F("transition", transition))
|
||||
|
||||
if err := InsertWorkspaceModule(ctx, s.Database, jobID, transition, module, telemetrySnapshot); err != nil {
|
||||
return nil, xerrors.Errorf("insert module: %w", err)
|
||||
if err := InsertWorkspaceModule(ctx, db, jobID, transition, module, telemetrySnapshot); err != nil {
|
||||
return xerrors.Errorf("insert module: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process rich parameters
|
||||
for _, richParameter := range jobType.TemplateImport.RichParameters {
|
||||
s.Logger.Info(ctx, "inserting template import job parameter",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
@ -1391,7 +1436,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
)
|
||||
options, err := json.Marshal(richParameter.Options)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal parameter options: %w", err)
|
||||
return xerrors.Errorf("marshal parameter options: %w", err)
|
||||
}
|
||||
|
||||
var validationMin, validationMax sql.NullInt32
|
||||
@ -1408,7 +1453,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = s.Database.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
|
||||
_, err = db.InsertTemplateVersionParameter(ctx, database.InsertTemplateVersionParameterParams{
|
||||
TemplateVersionID: input.TemplateVersionID,
|
||||
Name: richParameter.Name,
|
||||
DisplayName: richParameter.DisplayName,
|
||||
@ -1428,15 +1473,17 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
Ephemeral: richParameter.Ephemeral,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert parameter: %w", err)
|
||||
return xerrors.Errorf("insert parameter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = InsertWorkspacePresetsAndParameters(ctx, s.Logger, s.Database, jobID, input.TemplateVersionID, jobType.TemplateImport.Presets, now)
|
||||
// Process presets and parameters
|
||||
err := InsertWorkspacePresetsAndParameters(ctx, s.Logger, db, jobID, input.TemplateVersionID, jobType.TemplateImport.Presets, now)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert workspace presets and parameters: %w", err)
|
||||
return xerrors.Errorf("insert workspace presets and parameters: %w", err)
|
||||
}
|
||||
|
||||
// Process external auth providers
|
||||
var completedError sql.NullString
|
||||
|
||||
for _, externalAuthProvider := range jobType.TemplateImport.ExternalAuthProviders {
|
||||
@ -1479,18 +1526,19 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
|
||||
externalAuthProvidersMessage, err := json.Marshal(externalAuthProviders)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to serialize external_auth_providers value: %w", err)
|
||||
return xerrors.Errorf("failed to serialize external_auth_providers value: %w", err)
|
||||
}
|
||||
|
||||
err = s.Database.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
||||
err = db.UpdateTemplateVersionExternalAuthProvidersByJobID(ctx, database.UpdateTemplateVersionExternalAuthProvidersByJobIDParams{
|
||||
JobID: jobID,
|
||||
ExternalAuthProviders: externalAuthProvidersMessage,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update template version external auth providers: %w", err)
|
||||
return xerrors.Errorf("update template version external auth providers: %w", err)
|
||||
}
|
||||
|
||||
// Process terraform values
|
||||
plan := jobType.TemplateImport.Plan
|
||||
moduleFiles := jobType.TemplateImport.ModuleFiles
|
||||
// If there is a plan, or a module files archive we need to insert a
|
||||
@ -1509,7 +1557,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
hash := hex.EncodeToString(hashBytes[:])
|
||||
|
||||
// nolint:gocritic // Requires reading "system" files
|
||||
file, err := s.Database.GetFileByHashAndCreator(dbauthz.AsSystemRestricted(ctx), database.GetFileByHashAndCreatorParams{Hash: hash, CreatedBy: uuid.Nil})
|
||||
file, err := db.GetFileByHashAndCreator(dbauthz.AsSystemRestricted(ctx), database.GetFileByHashAndCreatorParams{Hash: hash, CreatedBy: uuid.Nil})
|
||||
switch {
|
||||
case err == nil:
|
||||
// This set of modules is already cached, which means we can reuse them
|
||||
@ -1518,10 +1566,10 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
UUID: file.ID,
|
||||
}
|
||||
case !xerrors.Is(err, sql.ErrNoRows):
|
||||
return nil, xerrors.Errorf("check for cached modules: %w", err)
|
||||
return xerrors.Errorf("check for cached modules: %w", err)
|
||||
default:
|
||||
// nolint:gocritic // Requires creating a "system" file
|
||||
file, err = s.Database.InsertFile(dbauthz.AsSystemRestricted(ctx), database.InsertFileParams{
|
||||
file, err = db.InsertFile(dbauthz.AsSystemRestricted(ctx), database.InsertFileParams{
|
||||
ID: uuid.New(),
|
||||
Hash: hash,
|
||||
CreatedBy: uuid.Nil,
|
||||
@ -1530,7 +1578,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
Data: moduleFiles,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert template version terraform modules: %w", err)
|
||||
return xerrors.Errorf("insert template version terraform modules: %w", err)
|
||||
}
|
||||
fileID = uuid.NullUUID{
|
||||
Valid: true,
|
||||
@ -1539,7 +1587,7 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Database.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{
|
||||
err = db.InsertTemplateVersionTerraformValuesByJobID(ctx, database.InsertTemplateVersionTerraformValuesByJobIDParams{
|
||||
JobID: jobID,
|
||||
UpdatedAt: now,
|
||||
CachedPlan: plan,
|
||||
@ -1547,11 +1595,12 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
ProvisionerdVersion: s.apiVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert template version terraform data: %w", err)
|
||||
return xerrors.Errorf("insert template version terraform data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
// Mark job as completed
|
||||
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{
|
||||
@ -1562,206 +1611,136 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
ErrorCode: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update provisioner job: %w", err)
|
||||
return xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
s.Logger.Debug(ctx, "marked import job as completed", slog.F("job_id", jobID))
|
||||
|
||||
case *proto.CompletedJob_WorkspaceBuild_:
|
||||
var input WorkspaceProvisionJob
|
||||
err = json.Unmarshal(job.Input, &input)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal job data: %w", err)
|
||||
return nil
|
||||
}, nil) // End of transaction
|
||||
}
|
||||
|
||||
// completeWorkspaceBuildJob handles completion of a workspace build job.
|
||||
// Most database operations are performed within a transaction.
|
||||
func (s *server) completeWorkspaceBuildJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_WorkspaceBuild_, telemetrySnapshot *telemetry.Snapshot) error {
|
||||
var input WorkspaceProvisionJob
|
||||
err := json.Unmarshal(job.Input, &input)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unmarshal job data: %w", err)
|
||||
}
|
||||
|
||||
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace build: %w", err)
|
||||
}
|
||||
|
||||
var workspace database.Workspace
|
||||
var getWorkspaceError error
|
||||
|
||||
// Execute all database modifications in a transaction
|
||||
err = s.Database.InTx(func(db database.Store) error {
|
||||
// It's important we use s.timeNow() here because we want to be
|
||||
// able to customize the current time from within tests.
|
||||
now := s.timeNow()
|
||||
|
||||
workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
if getWorkspaceError != nil {
|
||||
s.Logger.Error(ctx,
|
||||
"fetch workspace for build",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID),
|
||||
)
|
||||
return getWorkspaceError
|
||||
}
|
||||
|
||||
workspaceBuild, err := s.Database.GetWorkspaceBuildByID(ctx, input.WorkspaceBuildID)
|
||||
templateScheduleStore := *s.TemplateScheduleStore.Load()
|
||||
|
||||
autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{
|
||||
Database: db,
|
||||
TemplateScheduleStore: templateScheduleStore,
|
||||
UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(),
|
||||
Now: now,
|
||||
Workspace: workspace.WorkspaceTable(),
|
||||
// Allowed to be the empty string.
|
||||
WorkspaceAutostart: workspace.AutostartSchedule.String,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get workspace build: %w", err)
|
||||
return xerrors.Errorf("calculate auto stop: %w", err)
|
||||
}
|
||||
|
||||
var workspace database.Workspace
|
||||
var getWorkspaceError error
|
||||
|
||||
err = s.Database.InTx(func(db database.Store) error {
|
||||
// It's important we use s.timeNow() here because we want to be
|
||||
// able to customize the current time from within tests.
|
||||
now := s.timeNow()
|
||||
|
||||
workspace, getWorkspaceError = db.GetWorkspaceByID(ctx, workspaceBuild.WorkspaceID)
|
||||
if getWorkspaceError != nil {
|
||||
s.Logger.Error(ctx,
|
||||
"fetch workspace for build",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("workspace_id", workspaceBuild.WorkspaceID),
|
||||
)
|
||||
return getWorkspaceError
|
||||
}
|
||||
|
||||
templateScheduleStore := *s.TemplateScheduleStore.Load()
|
||||
|
||||
autoStop, err := schedule.CalculateAutostop(ctx, schedule.CalculateAutostopParams{
|
||||
Database: db,
|
||||
TemplateScheduleStore: templateScheduleStore,
|
||||
UserQuietHoursScheduleStore: *s.UserQuietHoursScheduleStore.Load(),
|
||||
Now: now,
|
||||
Workspace: workspace.WorkspaceTable(),
|
||||
// Allowed to be the empty string.
|
||||
WorkspaceAutostart: workspace.AutostartSchedule.String,
|
||||
})
|
||||
if workspace.AutostartSchedule.Valid {
|
||||
templateScheduleOptions, err := templateScheduleStore.Get(ctx, db, workspace.TemplateID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("calculate auto stop: %w", err)
|
||||
return xerrors.Errorf("get template schedule options: %w", err)
|
||||
}
|
||||
|
||||
if workspace.AutostartSchedule.Valid {
|
||||
templateScheduleOptions, err := templateScheduleStore.Get(ctx, db, workspace.TemplateID)
|
||||
nextStartAt, err := schedule.NextAllowedAutostart(now, workspace.AutostartSchedule.String, templateScheduleOptions)
|
||||
if err == nil {
|
||||
err = db.UpdateWorkspaceNextStartAt(ctx, database.UpdateWorkspaceNextStartAtParams{
|
||||
ID: workspace.ID,
|
||||
NextStartAt: sql.NullTime{Valid: true, Time: nextStartAt.UTC()},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template schedule options: %w", err)
|
||||
}
|
||||
|
||||
nextStartAt, err := schedule.NextAllowedAutostart(now, workspace.AutostartSchedule.String, templateScheduleOptions)
|
||||
if err == nil {
|
||||
err = db.UpdateWorkspaceNextStartAt(ctx, database.UpdateWorkspaceNextStartAtParams{
|
||||
ID: workspace.ID,
|
||||
NextStartAt: sql.NullTime{Valid: true, Time: nextStartAt.UTC()},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace next start at: %w", err)
|
||||
}
|
||||
return xerrors.Errorf("update workspace next start at: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
Error: sql.NullString{},
|
||||
ErrorCode: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: workspaceBuild.ID,
|
||||
ProvisionerState: jobType.WorkspaceBuild.State,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build provisioner state: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
|
||||
ID: workspaceBuild.ID,
|
||||
Deadline: autoStop.Deadline,
|
||||
MaxDeadline: autoStop.MaxDeadline,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build deadline: %w", err)
|
||||
}
|
||||
|
||||
agentTimeouts := make(map[time.Duration]bool) // A set of agent timeouts.
|
||||
// This could be a bulk insert to improve performance.
|
||||
for _, protoResource := range jobType.WorkspaceBuild.Resources {
|
||||
for _, protoAgent := range protoResource.Agents {
|
||||
dur := time.Duration(protoAgent.GetConnectionTimeoutSeconds()) * time.Second
|
||||
agentTimeouts[dur] = true
|
||||
}
|
||||
|
||||
err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert provisioner job: %w", err)
|
||||
}
|
||||
}
|
||||
for _, module := range jobType.WorkspaceBuild.Modules {
|
||||
if err := InsertWorkspaceModule(ctx, db, job.ID, workspaceBuild.Transition, module, telemetrySnapshot); err != nil {
|
||||
return xerrors.Errorf("insert provisioner job module: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// On start, we want to ensure that workspace agents timeout statuses
|
||||
// are propagated. This method is simple and does not protect against
|
||||
// notifying in edge cases like when a workspace is stopped soon
|
||||
// after being started.
|
||||
//
|
||||
// Agent timeouts could be minutes apart, resulting in an unresponsive
|
||||
// experience, so we'll notify after every unique timeout seconds.
|
||||
if !input.DryRun && workspaceBuild.Transition == database.WorkspaceTransitionStart && len(agentTimeouts) > 0 {
|
||||
timeouts := maps.Keys(agentTimeouts)
|
||||
slices.Sort(timeouts)
|
||||
|
||||
var updates []<-chan time.Time
|
||||
for _, d := range timeouts {
|
||||
s.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("timeout", d),
|
||||
)
|
||||
// Agents are inserted with `dbtime.Now()`, this triggers a
|
||||
// workspace event approximately after created + timeout seconds.
|
||||
updates = append(updates, time.After(d))
|
||||
}
|
||||
go func() {
|
||||
for _, wait := range updates {
|
||||
select {
|
||||
case <-s.lifecycleCtx.Done():
|
||||
// If the server is shutting down, we don't want to wait around.
|
||||
s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
)
|
||||
return
|
||||
case <-wait:
|
||||
// Wait for the next potential timeout to occur.
|
||||
msg, err := json.Marshal(wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindAgentTimeout,
|
||||
WorkspaceID: workspace.ID,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err))
|
||||
break
|
||||
}
|
||||
if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil {
|
||||
if s.lifecycleCtx.Err() != nil {
|
||||
// If the server is shutting down, we don't want to log this error, nor wait around.
|
||||
s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if workspaceBuild.Transition != database.WorkspaceTransitionDelete {
|
||||
// This is for deleting a workspace!
|
||||
return nil
|
||||
}
|
||||
|
||||
err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{
|
||||
ID: workspaceBuild.WorkspaceID,
|
||||
Deleted: true,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace deleted: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("complete job: %w", err)
|
||||
}
|
||||
|
||||
// Insert timings outside transaction since it is metadata.
|
||||
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
Error: sql.NullString{},
|
||||
ErrorCode: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildProvisionerStateByID(ctx, database.UpdateWorkspaceBuildProvisionerStateByIDParams{
|
||||
ID: workspaceBuild.ID,
|
||||
ProvisionerState: jobType.WorkspaceBuild.State,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build provisioner state: %w", err)
|
||||
}
|
||||
err = db.UpdateWorkspaceBuildDeadlineByID(ctx, database.UpdateWorkspaceBuildDeadlineByIDParams{
|
||||
ID: workspaceBuild.ID,
|
||||
Deadline: autoStop.Deadline,
|
||||
MaxDeadline: autoStop.MaxDeadline,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace build deadline: %w", err)
|
||||
}
|
||||
|
||||
agentTimeouts := make(map[time.Duration]bool) // A set of agent timeouts.
|
||||
// This could be a bulk insert to improve performance.
|
||||
for _, protoResource := range jobType.WorkspaceBuild.Resources {
|
||||
for _, protoAgent := range protoResource.Agents {
|
||||
dur := time.Duration(protoAgent.GetConnectionTimeoutSeconds()) * time.Second
|
||||
agentTimeouts[dur] = true
|
||||
}
|
||||
|
||||
err = InsertWorkspaceResource(ctx, db, job.ID, workspaceBuild.Transition, protoResource, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert provisioner job: %w", err)
|
||||
}
|
||||
}
|
||||
for _, module := range jobType.WorkspaceBuild.Modules {
|
||||
if err := InsertWorkspaceModule(ctx, db, job.ID, workspaceBuild.Transition, module, telemetrySnapshot); err != nil {
|
||||
return xerrors.Errorf("insert provisioner job module: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Insert timings inside the transaction now
|
||||
// nolint:exhaustruct // The other fields are set further down.
|
||||
params := database.InsertProvisionerJobTimingsParams{
|
||||
JobID: jobID,
|
||||
}
|
||||
for _, t := range completed.GetWorkspaceBuild().GetTimings() {
|
||||
for _, t := range jobType.WorkspaceBuild.Timings {
|
||||
if t.Start == nil || t.End == nil {
|
||||
s.Logger.Warn(ctx, "timings entry has nil start or end time", slog.F("entry", t.String()))
|
||||
continue
|
||||
@ -1780,153 +1759,229 @@ func (s *server) CompleteJob(ctx context.Context, completed *proto.CompletedJob)
|
||||
params.StartedAt = append(params.StartedAt, t.Start.AsTime())
|
||||
params.EndedAt = append(params.EndedAt, t.End.AsTime())
|
||||
}
|
||||
_, err = s.Database.InsertProvisionerJobTimings(ctx, params)
|
||||
_, err = db.InsertProvisionerJobTimings(ctx, params)
|
||||
if err != nil {
|
||||
// Don't fail the transaction for non-critical data.
|
||||
// Log error but don't fail the whole transaction for non-critical data
|
||||
s.Logger.Warn(ctx, "failed to update provisioner job timings", slog.F("job_id", jobID), slog.Error(err))
|
||||
}
|
||||
|
||||
// audit the outcome of the workspace build
|
||||
if getWorkspaceError == nil {
|
||||
// If the workspace has been deleted, notify the owner about it.
|
||||
if workspaceBuild.Transition == database.WorkspaceTransitionDelete {
|
||||
s.notifyWorkspaceDeleted(ctx, workspace, workspaceBuild)
|
||||
// On start, we want to ensure that workspace agents timeout statuses
|
||||
// are propagated. This method is simple and does not protect against
|
||||
// notifying in edge cases like when a workspace is stopped soon
|
||||
// after being started.
|
||||
//
|
||||
// Agent timeouts could be minutes apart, resulting in an unresponsive
|
||||
// experience, so we'll notify after every unique timeout seconds.
|
||||
if !input.DryRun && workspaceBuild.Transition == database.WorkspaceTransitionStart && len(agentTimeouts) > 0 {
|
||||
timeouts := maps.Keys(agentTimeouts)
|
||||
slices.Sort(timeouts)
|
||||
|
||||
var updates []<-chan time.Time
|
||||
for _, d := range timeouts {
|
||||
s.Logger.Debug(ctx, "triggering workspace notification after agent timeout",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.F("timeout", d),
|
||||
)
|
||||
// Agents are inserted with `dbtime.Now()`, this triggers a
|
||||
// workspace event approximately after created + timeout seconds.
|
||||
updates = append(updates, time.After(d))
|
||||
}
|
||||
|
||||
auditor := s.Auditor.Load()
|
||||
auditAction := auditActionFromTransition(workspaceBuild.Transition)
|
||||
|
||||
previousBuildNumber := workspaceBuild.BuildNumber - 1
|
||||
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: previousBuildNumber,
|
||||
})
|
||||
if prevBuildErr != nil {
|
||||
previousBuild = database.WorkspaceBuild{}
|
||||
}
|
||||
|
||||
// We pass the below information to the Auditor so that it
|
||||
// can form a friendly string for the user to view in the UI.
|
||||
buildResourceInfo := audit.AdditionalFields{
|
||||
WorkspaceName: workspace.Name,
|
||||
BuildNumber: strconv.FormatInt(int64(workspaceBuild.BuildNumber), 10),
|
||||
BuildReason: database.BuildReason(string(workspaceBuild.Reason)),
|
||||
WorkspaceID: workspace.ID,
|
||||
}
|
||||
|
||||
wriBytes, err := json.Marshal(buildResourceInfo)
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
|
||||
}
|
||||
|
||||
bag := audit.BaggageFromContext(ctx)
|
||||
|
||||
audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{
|
||||
Audit: *auditor,
|
||||
Log: s.Logger,
|
||||
UserID: job.InitiatorID,
|
||||
OrganizationID: workspace.OrganizationID,
|
||||
RequestID: job.ID,
|
||||
IP: bag.IP,
|
||||
Action: auditAction,
|
||||
Old: previousBuild,
|
||||
New: workspaceBuild,
|
||||
Status: http.StatusOK,
|
||||
AdditionalFields: wriBytes,
|
||||
})
|
||||
go func() {
|
||||
for _, wait := range updates {
|
||||
select {
|
||||
case <-s.lifecycleCtx.Done():
|
||||
// If the server is shutting down, we don't want to wait around.
|
||||
s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
)
|
||||
return
|
||||
case <-wait:
|
||||
// Wait for the next potential timeout to occur.
|
||||
msg, err := json.Marshal(wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindAgentTimeout,
|
||||
WorkspaceID: workspace.ID,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "marshal workspace update event", slog.Error(err))
|
||||
break
|
||||
}
|
||||
if err := s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg); err != nil {
|
||||
if s.lifecycleCtx.Err() != nil {
|
||||
// If the server is shutting down, we don't want to log this error, nor wait around.
|
||||
s.Logger.Debug(ctx, "stopping notifications due to server shutdown",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
s.Logger.Error(ctx, "workspace notification after agent timeout failed",
|
||||
slog.F("workspace_build_id", workspaceBuild.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// Track resource replacements, if there are any.
|
||||
orchestrator := s.PrebuildsOrchestrator.Load()
|
||||
if resourceReplacements := completed.GetWorkspaceBuild().GetResourceReplacements(); orchestrator != nil && len(resourceReplacements) > 0 {
|
||||
// Fire and forget. Bind to the lifecycle of the server so shutdowns are handled gracefully.
|
||||
go (*orchestrator).TrackResourceReplacement(s.lifecycleCtx, workspace.ID, workspaceBuild.ID, resourceReplacements)
|
||||
}
|
||||
if workspaceBuild.Transition != database.WorkspaceTransitionDelete {
|
||||
// This is for deleting a workspace!
|
||||
return nil
|
||||
}
|
||||
|
||||
msg, err := json.Marshal(wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindStateChange,
|
||||
WorkspaceID: workspace.ID,
|
||||
err = db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{
|
||||
ID: workspaceBuild.WorkspaceID,
|
||||
Deleted: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal workspace update event: %s", err)
|
||||
return xerrors.Errorf("update workspace deleted: %w", err)
|
||||
}
|
||||
err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg)
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("complete job: %w", err)
|
||||
}
|
||||
|
||||
// Post-transaction operations (operations that do not require transactions or
|
||||
// are external to the database, like audit logging, notifications, etc.)
|
||||
|
||||
// audit the outcome of the workspace build
|
||||
if getWorkspaceError == nil {
|
||||
// If the workspace has been deleted, notify the owner about it.
|
||||
if workspaceBuild.Transition == database.WorkspaceTransitionDelete {
|
||||
s.notifyWorkspaceDeleted(ctx, workspace, workspaceBuild)
|
||||
}
|
||||
|
||||
auditor := s.Auditor.Load()
|
||||
auditAction := auditActionFromTransition(workspaceBuild.Transition)
|
||||
|
||||
previousBuildNumber := workspaceBuild.BuildNumber - 1
|
||||
previousBuild, prevBuildErr := s.Database.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: previousBuildNumber,
|
||||
})
|
||||
if prevBuildErr != nil {
|
||||
previousBuild = database.WorkspaceBuild{}
|
||||
}
|
||||
|
||||
// We pass the below information to the Auditor so that it
|
||||
// can form a friendly string for the user to view in the UI.
|
||||
buildResourceInfo := audit.AdditionalFields{
|
||||
WorkspaceName: workspace.Name,
|
||||
BuildNumber: strconv.FormatInt(int64(workspaceBuild.BuildNumber), 10),
|
||||
BuildReason: database.BuildReason(string(workspaceBuild.Reason)),
|
||||
WorkspaceID: workspace.ID,
|
||||
}
|
||||
|
||||
wriBytes, err := json.Marshal(buildResourceInfo)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update workspace: %w", err)
|
||||
s.Logger.Error(ctx, "marshal resource info for successful job", slog.Error(err))
|
||||
}
|
||||
|
||||
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
s.Logger.Info(ctx, "workspace prebuild successfully claimed by user",
|
||||
slog.F("workspace_id", workspace.ID))
|
||||
bag := audit.BaggageFromContext(ctx)
|
||||
|
||||
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
|
||||
}
|
||||
audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.WorkspaceBuild]{
|
||||
Audit: *auditor,
|
||||
Log: s.Logger,
|
||||
UserID: job.InitiatorID,
|
||||
OrganizationID: workspace.OrganizationID,
|
||||
RequestID: job.ID,
|
||||
IP: bag.IP,
|
||||
Action: auditAction,
|
||||
Old: previousBuild,
|
||||
New: workspaceBuild,
|
||||
Status: http.StatusOK,
|
||||
AdditionalFields: wriBytes,
|
||||
})
|
||||
}
|
||||
|
||||
if s.PrebuildsOrchestrator != nil && input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
// Track resource replacements, if there are any.
|
||||
orchestrator := s.PrebuildsOrchestrator.Load()
|
||||
if resourceReplacements := jobType.WorkspaceBuild.ResourceReplacements; orchestrator != nil && len(resourceReplacements) > 0 {
|
||||
// Fire and forget. Bind to the lifecycle of the server so shutdowns are handled gracefully.
|
||||
go (*orchestrator).TrackResourceReplacement(s.lifecycleCtx, workspace.ID, workspaceBuild.ID, resourceReplacements)
|
||||
}
|
||||
case *proto.CompletedJob_TemplateDryRun_:
|
||||
}
|
||||
|
||||
msg, err := json.Marshal(wspubsub.WorkspaceEvent{
|
||||
Kind: wspubsub.WorkspaceEventKindStateChange,
|
||||
WorkspaceID: workspace.ID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal workspace update event: %s", err)
|
||||
}
|
||||
err = s.Pubsub.Publish(wspubsub.WorkspaceEventChannel(workspace.OwnerID), msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update workspace: %w", err)
|
||||
}
|
||||
|
||||
if input.PrebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
|
||||
s.Logger.Info(ctx, "workspace prebuild successfully claimed by user",
|
||||
slog.F("workspace_id", workspace.ID))
|
||||
|
||||
err = prebuilds.NewPubsubWorkspaceClaimPublisher(s.Pubsub).PublishWorkspaceClaim(agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: workspace.ID,
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish workspace claim event", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// completeTemplateDryRunJob handles completion of a template dry-run job.
|
||||
// All database operations are performed within a transaction.
|
||||
func (s *server) completeTemplateDryRunJob(ctx context.Context, job database.ProvisionerJob, jobID uuid.UUID, jobType *proto.CompletedJob_TemplateDryRun_, telemetrySnapshot *telemetry.Snapshot) error {
|
||||
// Execute all database operations in a transaction
|
||||
return s.Database.InTx(func(db database.Store) error {
|
||||
now := s.timeNow()
|
||||
|
||||
// Process resources
|
||||
for _, resource := range jobType.TemplateDryRun.Resources {
|
||||
s.Logger.Info(ctx, "inserting template dry-run job resource",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
slog.F("resource_name", resource.Name),
|
||||
slog.F("resource_type", resource.Type))
|
||||
|
||||
err = InsertWorkspaceResource(ctx, s.Database, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
|
||||
err := InsertWorkspaceResource(ctx, db, jobID, database.WorkspaceTransitionStart, resource, telemetrySnapshot)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert resource: %w", err)
|
||||
return xerrors.Errorf("insert resource: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Process modules
|
||||
for _, module := range jobType.TemplateDryRun.Modules {
|
||||
s.Logger.Info(ctx, "inserting template dry-run job module",
|
||||
slog.F("job_id", job.ID.String()),
|
||||
slog.F("module_source", module.Source),
|
||||
)
|
||||
|
||||
if err := InsertWorkspaceModule(ctx, s.Database, jobID, database.WorkspaceTransitionStart, module, telemetrySnapshot); err != nil {
|
||||
return nil, xerrors.Errorf("insert module: %w", err)
|
||||
if err := InsertWorkspaceModule(ctx, db, jobID, database.WorkspaceTransitionStart, module, telemetrySnapshot); err != nil {
|
||||
return xerrors.Errorf("insert module: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
// Mark job as complete
|
||||
err := db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: jobID,
|
||||
UpdatedAt: s.timeNow(),
|
||||
UpdatedAt: now,
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: s.timeNow(),
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
Error: sql.NullString{},
|
||||
ErrorCode: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update provisioner job: %w", err)
|
||||
return xerrors.Errorf("update provisioner job: %w", err)
|
||||
}
|
||||
s.Logger.Debug(ctx, "marked template dry-run job as completed", slog.F("job_id", jobID))
|
||||
|
||||
default:
|
||||
if completed.Type == nil {
|
||||
return nil, xerrors.Errorf("type payload must be provided")
|
||||
}
|
||||
return nil, xerrors.Errorf("unknown job type %q; ensure coderd and provisionerd versions match",
|
||||
reflect.TypeOf(completed.Type).String())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{EndOfLogs: true})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal job log: %w", err)
|
||||
}
|
||||
err = s.Pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
|
||||
if err != nil {
|
||||
s.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
|
||||
return nil, xerrors.Errorf("publish end of job logs: %w", err)
|
||||
}
|
||||
|
||||
s.Logger.Debug(ctx, "stage CompleteJob done", slog.F("job_id", jobID))
|
||||
return &proto.Empty{}, nil
|
||||
return nil
|
||||
}, nil) // End of transaction
|
||||
}
|
||||
|
||||
func (s *server) notifyWorkspaceDeleted(ctx context.Context, workspace database.Workspace, build database.WorkspaceBuild) {
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"storj.io/drpc"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
@ -1119,6 +1120,227 @@ func TestCompleteJob(t *testing.T) {
|
||||
require.ErrorContains(t, err, "you don't own this job")
|
||||
})
|
||||
|
||||
// Test for verifying transaction behavior on the extracted methods
|
||||
t.Run("TransactionBehavior", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test TemplateImport transaction
|
||||
t.Run("TemplateImportTransaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
jobID := uuid.New()
|
||||
versionID := uuid.New()
|
||||
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
|
||||
ID: versionID,
|
||||
JobID: jobID,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
OrganizationID: pd.OrganizationID,
|
||||
ID: jobID,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
Type: database.ProvisionerJobTypeTemplateVersionImport,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
OrganizationID: pd.OrganizationID,
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
||||
JobId: job.ID.String(),
|
||||
Type: &proto.CompletedJob_TemplateImport_{
|
||||
TemplateImport: &proto.CompletedJob_TemplateImport{
|
||||
StartResources: []*sdkproto.Resource{{
|
||||
Name: "test-resource",
|
||||
Type: "aws_instance",
|
||||
}},
|
||||
Plan: []byte("{}"),
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify job was marked as completed
|
||||
completedJob, err := db.GetProvisionerJobByID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed")
|
||||
|
||||
// Verify resources were created
|
||||
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resources, 1, "Expected one resource to be created")
|
||||
require.Equal(t, "test-resource", resources[0].Name)
|
||||
})
|
||||
|
||||
// Test TemplateDryRun transaction
|
||||
t.Run("TemplateDryRunTransaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
||||
JobId: job.ID.String(),
|
||||
Type: &proto.CompletedJob_TemplateDryRun_{
|
||||
TemplateDryRun: &proto.CompletedJob_TemplateDryRun{
|
||||
Resources: []*sdkproto.Resource{{
|
||||
Name: "test-dry-run-resource",
|
||||
Type: "aws_instance",
|
||||
}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify job was marked as completed
|
||||
completedJob, err := db.GetProvisionerJobByID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed")
|
||||
|
||||
// Verify resources were created
|
||||
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resources, 1, "Expected one resource to be created")
|
||||
require.Equal(t, "test-dry-run-resource", resources[0].Name)
|
||||
})
|
||||
|
||||
// Test WorkspaceBuild transaction
|
||||
t.Run("WorkspaceBuildTransaction", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, ps, pd := setup(t, false, &overrides{})
|
||||
|
||||
// Create test data
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
template := dbgen.Template(t, db, database.Template{
|
||||
Name: "template",
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
|
||||
workspaceTable := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
TemplateID: template.ID,
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: pd.OrganizationID,
|
||||
TemplateID: uuid.NullUUID{
|
||||
UUID: template.ID,
|
||||
Valid: true,
|
||||
},
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspaceTable.ID,
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
})
|
||||
job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
|
||||
FileID: file.ID,
|
||||
InitiatorID: user.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: build.ID,
|
||||
})),
|
||||
OrganizationID: pd.OrganizationID,
|
||||
})
|
||||
_, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
|
||||
OrganizationID: pd.OrganizationID,
|
||||
WorkerID: uuid.NullUUID{
|
||||
UUID: pd.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a published channel to make sure the workspace event is sent
|
||||
publishedWorkspace := make(chan struct{})
|
||||
closeWorkspaceSubscribe, err := ps.SubscribeWithErr(wspubsub.WorkspaceEventChannel(workspaceTable.OwnerID),
|
||||
wspubsub.HandleWorkspaceEvent(
|
||||
func(_ context.Context, e wspubsub.WorkspaceEvent, err error) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if e.Kind == wspubsub.WorkspaceEventKindStateChange && e.WorkspaceID == workspaceTable.ID {
|
||||
close(publishedWorkspace)
|
||||
}
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
defer closeWorkspaceSubscribe()
|
||||
|
||||
// The actual test
|
||||
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{
|
||||
JobId: job.ID.String(),
|
||||
Type: &proto.CompletedJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.CompletedJob_WorkspaceBuild{
|
||||
State: []byte{},
|
||||
Resources: []*sdkproto.Resource{{
|
||||
Name: "test-workspace-resource",
|
||||
Type: "aws_instance",
|
||||
}},
|
||||
Timings: []*sdkproto.Timing{{
|
||||
Stage: "test",
|
||||
Source: "test-source",
|
||||
Resource: "test-resource",
|
||||
Action: "test-action",
|
||||
Start: timestamppb.Now(),
|
||||
End: timestamppb.Now(),
|
||||
}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for workspace notification
|
||||
select {
|
||||
case <-publishedWorkspace:
|
||||
// Success
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("Workspace event not published")
|
||||
}
|
||||
|
||||
// Verify job was marked as completed
|
||||
completedJob, err := db.GetProvisionerJobByID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, completedJob.CompletedAt.Valid, "Job should be marked as completed")
|
||||
|
||||
// Verify resources were created
|
||||
resources, err := db.GetWorkspaceResourcesByJobID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resources, 1, "Expected one resource to be created")
|
||||
require.Equal(t, "test-workspace-resource", resources[0].Name)
|
||||
|
||||
// Verify timings were recorded
|
||||
timings, err := db.GetProvisionerJobTimingsByJobID(ctx, job.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, timings, 1, "Expected one timing entry to be created")
|
||||
require.Equal(t, "test", string(timings[0].Stage), "Timing stage should match what was sent")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("TemplateImport_MissingGitAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv, db, _, pd := setup(t, false, &overrides{})
|
||||
|
Reference in New Issue
Block a user