feat: add provisioner job hang detector (#7927)

This commit is contained in:
Dean Sheather
2023-06-25 23:17:00 +10:00
committed by GitHub
parent 3671846b1b
commit 98a5ae7f48
28 changed files with 1414 additions and 54 deletions

3
coderd/apidoc/docs.go generated
View File

@ -7239,6 +7239,9 @@ const docTemplate = `{
"in_memory_database": {
"type": "boolean"
},
"job_hang_detector_interval": {
"type": "integer"
},
"logging": {
"$ref": "#/definitions/codersdk.LoggingConfig"
},

View File

@ -6468,6 +6468,9 @@
"in_memory_database": {
"type": "boolean"
},
"job_hang_detector_interval": {
"type": "integer"
},
"logging": {
"$ref": "#/definitions/codersdk.LoggingConfig"
},

View File

@ -68,6 +68,7 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/schedule"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/coderd/workspaceapps"
@ -256,6 +257,12 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
).WithStatsChannel(options.AutobuildStats)
lifecycleExecutor.Run()
hangDetectorTicker := time.NewTicker(options.DeploymentValues.JobHangDetectorInterval.Value())
defer hangDetectorTicker.Stop()
hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, slogtest.Make(t, nil).Named("unhanger.detector"), hangDetectorTicker.C)
hangDetector.Start()
t.Cleanup(hangDetector.Close)
var mutex sync.RWMutex
var handler http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -18,13 +18,6 @@ import (
"golang.org/x/xerrors"
)
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept to avoid reusing the
// same ID.
const (
LockIDDeploymentSetup = iota + 1
)
// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {

View File

@ -3,7 +3,6 @@ package db2sdk
import (
"encoding/json"
"time"
"github.com/google/uuid"
@ -81,6 +80,9 @@ func TemplateVersionParameter(param database.TemplateVersionParameter) (codersdk
}
func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.ProvisionerJobStatus {
// The case where jobs are hung is handled by the unhang package. We can't
// just return Failed here when it's hung because that doesn't reflect in
// the database.
switch {
case provisionerJob.CanceledAt.Valid:
if !provisionerJob.CompletedAt.Valid {
@ -97,8 +99,6 @@ func ProvisionerJobStatus(provisionerJob database.ProvisionerJob) codersdk.Provi
return codersdk.ProvisionerJobSucceeded
}
return codersdk.ProvisionerJobFailed
case database.Now().Sub(provisionerJob.UpdatedAt) > 30*time.Second:
return codersdk.ProvisionerJobFailed
default:
return codersdk.ProvisionerJobRunning
}

View File

@ -96,17 +96,6 @@ func TestProvisionerJobStatus(t *testing.T) {
},
status: codersdk.ProvisionerJobFailed,
},
{
name: "not_updated",
job: database.ProvisionerJob{
StartedAt: sql.NullTime{
Time: database.Now().Add(-time.Minute),
Valid: true,
},
UpdatedAt: database.Now().Add(-31 * time.Second),
},
status: codersdk.ProvisionerJobFailed,
},
{
name: "updated",
job: database.ProvisionerJob{

View File

@ -176,6 +176,25 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
// See unhanger package.
subjectHangDetector = rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Name: "hangdetector",
DisplayName: "Hang Detector Daemon",
Site: rbac.Permissions(map[string][]rbac.Action{
rbac.ResourceSystem.Type: {rbac.WildcardSymbol},
rbac.ResourceTemplate.Type: {rbac.ActionRead},
rbac.ResourceWorkspace.Type: {rbac.ActionRead, rbac.ActionUpdate},
}),
Org: map[string][]rbac.Permission{},
User: []rbac.Permission{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemRestricted = rbac.Subject{
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
@ -217,6 +236,12 @@ func AsAutostart(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectAutostart)
}
// AsHangDetector returns a context with an actor that has permissions required
// for unhanger.Detector to function.
func AsHangDetector(ctx context.Context) context.Context {
return context.WithValue(ctx, authContextKey{}, subjectHangDetector)
}
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
func AsSystemRestricted(ctx context.Context) context.Context {
@ -950,6 +975,14 @@ func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID
return fetchWithPostFilter(q.auth, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
}
// TODO: We need to create a ProvisionerJob resource type
func (q *querier) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
// if err := q.authorizeContext(ctx, rbac.ActionCreate, rbac.ResourceSystem); err != nil {
// return nil, err
// }
return q.db.GetHungProvisionerJobs(ctx, hungSince)
}
func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
return "", err

View File

@ -1753,6 +1753,19 @@ func (q *fakeQuerier) GetGroupsByOrganizationID(_ context.Context, organizationI
return groups, nil
}
func (q *fakeQuerier) GetHungProvisionerJobs(_ context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
hungJobs := []database.ProvisionerJob{}
for _, provisionerJob := range q.provisionerJobs {
if provisionerJob.StartedAt.Valid && !provisionerJob.CompletedAt.Valid && provisionerJob.UpdatedAt.Before(hungSince) {
hungJobs = append(hungJobs, provisionerJob)
}
}
return hungJobs, nil
}
func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2135,7 +2148,7 @@ func (q *fakeQuerier) GetProvisionerLogsAfterID(_ context.Context, arg database.
if jobLog.JobID != arg.JobID {
continue
}
if arg.CreatedAfter != 0 && jobLog.ID < arg.CreatedAfter {
if jobLog.ID <= arg.CreatedAfter {
continue
}
logs = append(logs, jobLog)

View File

@ -399,6 +399,13 @@ func (m metricsStore) GetGroupsByOrganizationID(ctx context.Context, organizatio
return groups, err
}
func (m metricsStore) GetHungProvisionerJobs(ctx context.Context, hungSince time.Time) ([]database.ProvisionerJob, error) {
start := time.Now()
jobs, err := m.s.GetHungProvisionerJobs(ctx, hungSince)
m.queryLatencies.WithLabelValues("GetHungProvisionerJobs").Observe(time.Since(start).Seconds())
return jobs, err
}
func (m metricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) {
start := time.Now()
version, err := m.s.GetLastUpdateCheck(ctx)

View File

@ -701,6 +701,21 @@ func (mr *MockStoreMockRecorder) GetGroupsByOrganizationID(arg0, arg1 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsByOrganizationID", reflect.TypeOf((*MockStore)(nil).GetGroupsByOrganizationID), arg0, arg1)
}
// GetHungProvisionerJobs mocks base method.
func (m *MockStore) GetHungProvisionerJobs(arg0 context.Context, arg1 time.Time) ([]database.ProvisionerJob, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetHungProvisionerJobs", arg0, arg1)
ret0, _ := ret[0].([]database.ProvisionerJob)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetHungProvisionerJobs indicates an expected call of GetHungProvisionerJobs.
func (mr *MockStoreMockRecorder) GetHungProvisionerJobs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHungProvisionerJobs", reflect.TypeOf((*MockStore)(nil).GetHungProvisionerJobs), arg0, arg1)
}
// GetLastUpdateCheck mocks base method.
func (m *MockStore) GetLastUpdateCheck(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()

19
coderd/database/lock.go Normal file
View File

@ -0,0 +1,19 @@
package database
import "hash/fnv"
// Well-known lock IDs for lock functions in the database. These should not
// change. If locks are deprecated, they should be kept in this list to avoid
// reusing the same ID.
const (
// Keep the unused iota here so we don't need + 1 every time
lockIDUnused = iota
LockIDDeploymentSetup
)
// GenLockID generates a unique and consistent lock ID from a given string.
func GenLockID(name string) int64 {
hash := fnv.New64()
_, _ = hash.Write([]byte(name))
return int64(hash.Sum64())
}

View File

@ -16,8 +16,6 @@ type sqlcQuerier interface {
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error
// Acquires the lock for a single job that isn't started, completed,
// canceled, and that matches an array of provisioner types.
@ -75,6 +73,7 @@ type sqlcQuerier interface {
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error)
GetLastUpdateCheck(ctx context.Context) (string, error)
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
@ -217,8 +216,6 @@ type sqlcQuerier interface {
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateGitAuthLink(ctx context.Context, arg UpdateGitAuthLinkParams) (GitAuthLink, error)

View File

@ -1527,8 +1527,6 @@ SELECT pg_advisory_xact_lock($1)
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
func (q *sqlQuerier) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error {
_, err := q.db.ExecContext(ctx, acquireLock, pgAdvisoryXactLock)
return err
@ -1542,8 +1540,6 @@ SELECT pg_try_advisory_xact_lock($1)
//
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
//
// Use database.LockID() to generate a unique lock ID from a string.
func (q *sqlQuerier) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
row := q.db.QueryRowContext(ctx, tryAcquireLock, pgTryAdvisoryXactLock)
var pg_try_advisory_xact_lock bool
@ -2201,6 +2197,59 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
return i, err
}
const getHungProvisionerJobs = `-- name: GetHungProvisionerJobs :many
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata
FROM
provisioner_jobs
WHERE
updated_at < $1
AND started_at IS NOT NULL
AND completed_at IS NULL
`
func (q *sqlQuerier) GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error) {
rows, err := q.db.QueryContext(ctx, getHungProvisionerJobs, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ProvisionerJob
for rows.Next() {
var i ProvisionerJob
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.StartedAt,
&i.CanceledAt,
&i.CompletedAt,
&i.Error,
&i.OrganizationID,
&i.InitiatorID,
&i.Provisioner,
&i.StorageMethod,
&i.Type,
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
&i.ErrorCode,
&i.TraceMetadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags, error_code, trace_metadata

View File

@ -3,8 +3,6 @@
--
-- This must be called from within a transaction. The lock will be automatically
-- released when the transaction ends.
--
-- Use database.LockID() to generate a unique lock ID from a string.
SELECT pg_advisory_xact_lock($1);
-- name: TryAcquireLock :one
@ -12,6 +10,4 @@ SELECT pg_advisory_xact_lock($1);
--
-- This must be called from within a transaction. The lock will be automatically
-- released when the transaction ends.
--
-- Use database.LockID() to generate a unique lock ID from a string.
SELECT pg_try_advisory_xact_lock($1);

View File

@ -128,3 +128,13 @@ SET
error_code = $5
WHERE
id = $1;
-- name: GetHungProvisionerJobs :many
SELECT
*
FROM
provisioner_jobs
WHERE
updated_at < $1
AND started_at IS NOT NULL
AND completed_at IS NULL;

363
coderd/unhanger/detector.go Normal file
View File

@ -0,0 +1,363 @@
package unhanger
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math/rand" //#nosec // this is only used for shuffling an array to pick random jobs to unhang
"time"
"golang.org/x/xerrors"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/db2sdk"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
)
const (
// HungJobDuration is the duration of time since the last update to a job
// before it is considered hung.
HungJobDuration = 5 * time.Minute
// HungJobExitTimeout is the duration of time that provisioners should allow
// for a graceful exit upon cancellation due to failing to send an update to
// a job.
//
// Provisioners should avoid keeping a job "running" for longer than this
// time after failing to send an update to the job.
HungJobExitTimeout = 3 * time.Minute
// MaxJobsPerRun is the maximum number of hung jobs that the detector will
// terminate in a single run.
MaxJobsPerRun = 10
)
// HungJobLogMessages are written to provisioner job logs when a job is hung and
// terminated.
var HungJobLogMessages = []string{
"",
"====================",
"Coder: Build has been detected as hung for 5 minutes and will be terminated.",
"====================",
"",
}
// acquireLockError is returned when the detector fails to acquire a lock and
// cancels the current run.
type acquireLockError struct{}
// Error implements error.
func (acquireLockError) Error() string {
return "lock is held by another client"
}
// jobInelligibleError is returned when a job is not eligible to be terminated
// anymore.
type jobInelligibleError struct {
Err error
}
// Error implements error.
func (e jobInelligibleError) Error() string {
return fmt.Sprintf("job is no longer eligible to be terminated: %s", e.Err)
}
// Detector automatically detects hung provisioner jobs, sends messages into the
// build log and terminates them as failed.
type Detector struct {
ctx context.Context
cancel context.CancelFunc
done chan struct{}
db database.Store
pubsub pubsub.Pubsub
log slog.Logger
tick <-chan time.Time
stats chan<- Stats
}
// Stats contains statistics about the last run of the detector.
type Stats struct {
// TerminatedJobIDs contains the IDs of all jobs that were detected as hung and
// terminated.
TerminatedJobIDs []uuid.UUID
// Error is the fatal error that occurred during the last run of the
// detector, if any. Error may be set to AcquireLockError if the detector
// failed to acquire a lock.
Error error
}
// New returns a new hang detector.
func New(ctx context.Context, db database.Store, pub pubsub.Pubsub, log slog.Logger, tick <-chan time.Time) *Detector {
//nolint:gocritic // Hang detector has a limited set of permissions.
ctx, cancel := context.WithCancel(dbauthz.AsHangDetector(ctx))
d := &Detector{
ctx: ctx,
cancel: cancel,
done: make(chan struct{}),
db: db,
pubsub: pub,
log: log,
tick: tick,
stats: nil,
}
return d
}
// WithStatsChannel will cause Executor to push a RunStats to ch after
// every tick. This push is blocking, so if ch is not read, the detector will
// hang. This should only be used in tests.
func (d *Detector) WithStatsChannel(ch chan<- Stats) *Detector {
d.stats = ch
return d
}
// Start will cause the detector to detect and unhang provisioner jobs on every
// tick from its channel. It will stop when its context is Done, or when its
// channel is closed.
//
// Start should only be called once.
func (d *Detector) Start() {
go func() {
defer close(d.done)
defer d.cancel()
for {
select {
case <-d.ctx.Done():
return
case t, ok := <-d.tick:
if !ok {
return
}
stats := d.run(t)
if stats.Error != nil && !xerrors.As(stats.Error, &acquireLockError{}) {
d.log.Warn(d.ctx, "error running workspace build hang detector once", slog.Error(stats.Error))
}
if len(stats.TerminatedJobIDs) != 0 {
d.log.Warn(d.ctx, "detected (and terminated) hung provisioner jobs", slog.F("job_ids", stats.TerminatedJobIDs))
}
if d.stats != nil {
select {
case <-d.ctx.Done():
return
case d.stats <- stats:
}
}
}
}
}()
}
// Wait will block until the detector is stopped.
func (d *Detector) Wait() {
<-d.done
}
// Close will stop the detector.
func (d *Detector) Close() {
d.cancel()
<-d.done
}
func (d *Detector) run(t time.Time) Stats {
ctx, cancel := context.WithTimeout(d.ctx, 5*time.Minute)
defer cancel()
stats := Stats{
TerminatedJobIDs: []uuid.UUID{},
Error: nil,
}
// Find all provisioner jobs that are currently running but have not
// received an update in the last 5 minutes.
jobs, err := d.db.GetHungProvisionerJobs(ctx, t.Add(-HungJobDuration))
if err != nil {
stats.Error = xerrors.Errorf("get hung provisioner jobs: %w", err)
return stats
}
// Limit the number of jobs we'll unhang in a single run to avoid
// timing out.
if len(jobs) > MaxJobsPerRun {
// Pick a random subset of the jobs to unhang.
rand.Shuffle(len(jobs), func(i, j int) {
jobs[i], jobs[j] = jobs[j], jobs[i]
})
jobs = jobs[:MaxJobsPerRun]
}
// Send a message into the build log for each hung job saying that it
// has been detected and will be terminated, then mark the job as
// failed.
for _, job := range jobs {
log := d.log.With(slog.F("job_id", job.ID))
err := unhangJob(ctx, log, d.db, d.pubsub, job.ID)
if err != nil && !(xerrors.As(err, &acquireLockError{}) || xerrors.As(err, &jobInelligibleError{})) {
log.Error(ctx, "error forcefully terminating hung provisioner job", slog.Error(err))
continue
}
stats.TerminatedJobIDs = append(stats.TerminatedJobIDs, job.ID)
}
return stats
}
func unhangJob(ctx context.Context, log slog.Logger, db database.Store, pub pubsub.Pubsub, jobID uuid.UUID) error {
var lowestLogID int64
err := db.InTx(func(db database.Store) error {
locked, err := db.TryAcquireLock(ctx, database.GenLockID(fmt.Sprintf("hang-detector:%s", jobID)))
if err != nil {
return xerrors.Errorf("acquire lock: %w", err)
}
if !locked {
// This error is ignored.
return acquireLockError{}
}
// Refetch the job while we hold the lock.
job, err := db.GetProvisionerJobByID(ctx, jobID)
if err != nil {
return xerrors.Errorf("get provisioner job: %w", err)
}
// Check if we should still unhang it.
jobStatus := db2sdk.ProvisionerJobStatus(job)
if jobStatus != codersdk.ProvisionerJobRunning {
return jobInelligibleError{
Err: xerrors.Errorf("job is not running (status %s)", jobStatus),
}
}
if job.UpdatedAt.After(time.Now().Add(-HungJobDuration)) {
return jobInelligibleError{
Err: xerrors.New("job has been updated recently"),
}
}
log.Info(ctx, "detected hung (>5m) provisioner job, forcefully terminating")
// First, get the latest logs from the build so we can make sure
// our messages are in the latest stage.
logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: job.ID,
CreatedAfter: 0,
})
if err != nil {
return xerrors.Errorf("get logs for hung job: %w", err)
}
logStage := ""
if len(logs) != 0 {
logStage = logs[len(logs)-1].Stage
}
if logStage == "" {
logStage = "Unknown"
}
// Insert the messages into the build log.
insertParams := database.InsertProvisionerJobLogsParams{
JobID: job.ID,
}
now := database.Now()
for i, msg := range HungJobLogMessages {
// Set the created at in a way that ensures each message has
// a unique timestamp so they will be sorted correctly.
insertParams.CreatedAt = append(insertParams.CreatedAt, now.Add(time.Millisecond*time.Duration(i)))
insertParams.Level = append(insertParams.Level, database.LogLevelError)
insertParams.Stage = append(insertParams.Stage, logStage)
insertParams.Source = append(insertParams.Source, database.LogSourceProvisionerDaemon)
insertParams.Output = append(insertParams.Output, msg)
}
newLogs, err := db.InsertProvisionerJobLogs(ctx, insertParams)
if err != nil {
return xerrors.Errorf("insert logs for hung job: %w", err)
}
lowestLogID = newLogs[0].ID
// Mark the job as failed.
now = database.Now()
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
UpdatedAt: now,
CompletedAt: sql.NullTime{
Time: now,
Valid: true,
},
Error: sql.NullString{
String: "Coder: Build has been detected as hung for 5 minutes and has been terminated by hang detector.",
Valid: true,
},
ErrorCode: sql.NullString{
Valid: false,
},
})
if err != nil {
return xerrors.Errorf("mark job as failed: %w", err)
}
// If the provisioner job is a workspace build, copy the
// provisioner state from the previous build to this workspace
// build.
if job.Type == database.ProvisionerJobTypeWorkspaceBuild {
build, err := db.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
return xerrors.Errorf("get workspace build for workspace build job by job id: %w", err)
}
// Only copy the provisioner state if there's no state in
// the current build.
if len(build.ProvisionerState) == 0 {
// Get the previous build if it exists.
prevBuild, err := db.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: build.WorkspaceID,
BuildNumber: build.BuildNumber - 1,
})
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get previous workspace build: %w", err)
}
if err == nil {
_, err = db.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{
ID: build.ID,
UpdatedAt: database.Now(),
ProvisionerState: prevBuild.ProvisionerState,
Deadline: time.Time{},
MaxDeadline: time.Time{},
})
if err != nil {
return xerrors.Errorf("update workspace build by id: %w", err)
}
}
}
}
return nil
}, nil)
if err != nil {
return xerrors.Errorf("in tx: %w", err)
}
// Publish the new log notification to pubsub. Use the lowest log ID
// inserted so the log stream will fetch everything after that point.
data, err := json.Marshal(provisionersdk.ProvisionerJobLogsNotifyMessage{
CreatedAfter: lowestLogID - 1,
EndOfLogs: true,
})
if err != nil {
return xerrors.Errorf("marshal log notification: %w", err)
}
err = pub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(jobID), data)
if err != nil {
return xerrors.Errorf("publish log notification: %w", err)
}
return nil
}

View File

@ -0,0 +1,724 @@
package unhanger_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/unhanger"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDetectorNoJobs(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- time.Now()
stats := <-statsCh
require.NoError(t, stats.Error)
require.Empty(t, stats.TerminatedJobIDs)
detector.Close()
detector.Wait()
}
func TestDetectorNoHungJobs(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
// Insert some jobs that are running and haven't been updated in a while,
// but not enough to be considered hung.
now := time.Now()
org := dbgen.Organization(t, db, database.Organization{})
user := dbgen.User(t, db, database.User{})
file := dbgen.File(t, db, database.File{})
for i := 0; i < 5; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: now.Add(-time.Minute * 5),
UpdatedAt: now.Add(-time.Minute * time.Duration(i)),
StartedAt: sql.NullTime{
Time: now.Add(-time.Minute * 5),
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Empty(t, stats.TerminatedJobIDs)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuild(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
twentyMinAgo = now.Add(-time.Minute * 20)
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// Previous build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
CompletedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
ProvisionerState: expectedWorkspaceBuildState,
JobID: previousWorkspaceBuildJob.ID,
})
// Current build.
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 2,
JobID: currentWorkspaceBuildJob.ID,
// No provisioner state.
})
)
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
twentyMinAgo = now.Add(-time.Minute * 20)
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// Previous build.
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
CompletedAt: sql.NullTime{
Time: twentyMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
ProvisionerState: []byte(`{"dean":"NOT cool","colin":"also NOT cool"}`),
JobID: previousWorkspaceBuildJob.ID,
})
// Current build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 2,
JobID: currentWorkspaceBuildJob.ID,
// Should not be overridden.
ProvisionerState: expectedWorkspaceBuildState,
})
)
t.Log("previous job ID: ", previousWorkspaceBuildJob.ID)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was NOT copied.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
template = dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
templateVersion = dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
CreatedBy: user.ID,
})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: template.ID,
})
// First build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: []byte("{}"),
})
currentWorkspaceBuild = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
TemplateVersionID: templateVersion.ID,
BuildNumber: 1,
JobID: currentWorkspaceBuildJob.ID,
// Should not be overridden.
ProvisionerState: expectedWorkspaceBuildState,
})
)
t.Log("current job ID: ", currentWorkspaceBuildJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Equal(t, currentWorkspaceBuildJob.ID, stats.TerminatedJobIDs[0])
// Check that the current provisioner job was updated.
job, err := db.GetProvisionerJobByID(ctx, currentWorkspaceBuildJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the provisioner state was NOT updated.
build, err := db.GetWorkspaceBuildByID(ctx, currentWorkspaceBuild.ID)
require.NoError(t, err)
require.Equal(t, expectedWorkspaceBuildState, build.ProvisionerState)
detector.Close()
detector.Wait()
}
func TestDetectorHungOtherJobTypes(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
// Template dry-run job.
templateDryRunJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: []byte("{}"),
})
)
t.Log("template import job ID: ", templateImportJob.ID)
t.Log("template dry-run job ID: ", templateDryRunJob.ID)
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 2)
require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID)
require.Contains(t, stats.TerminatedJobIDs, templateDryRunJob.ID)
// Check that the template import job was updated.
job, err := db.GetProvisionerJobByID(ctx, templateImportJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
// Check that the template dry-run job was updated.
job, err = db.GetProvisionerJobByID(ctx, templateDryRunJob.ID)
require.NoError(t, err)
require.WithinDuration(t, now, job.UpdatedAt, 30*time.Second)
require.True(t, job.CompletedAt.Valid)
require.WithinDuration(t, now, job.CompletedAt.Time, 30*time.Second)
require.True(t, job.Error.Valid)
require.Contains(t, job.Error.String, "Build has been detected as hung")
require.False(t, job.ErrorCode.Valid)
detector.Close()
detector.Wait()
}
func TestDetectorPushesLogs(t *testing.T) {
t.Parallel()
cases := []struct {
name string
preLogCount int
preLogStage string
expectStage string
}{
{
name: "WithExistingLogs",
preLogCount: 10,
preLogStage: "Stage Name",
expectStage: "Stage Name",
},
{
name: "WithExistingLogsNoStage",
preLogCount: 10,
preLogStage: "",
expectStage: "Unknown",
},
{
name: "WithoutExistingLogs",
preLogCount: 0,
expectStage: "Unknown",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
)
var (
now = time.Now()
tenMinAgo = now.Add(-time.Minute * 10)
sixMinAgo = now.Add(-time.Minute * 6)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
Time: tenMinAgo,
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
)
t.Log("template import job ID: ", templateImportJob.ID)
// Insert some logs at the start of the job.
if c.preLogCount > 0 {
insertParams := database.InsertProvisionerJobLogsParams{
JobID: templateImportJob.ID,
}
for i := 0; i < c.preLogCount; i++ {
insertParams.CreatedAt = append(insertParams.CreatedAt, tenMinAgo.Add(time.Millisecond*time.Duration(i)))
insertParams.Level = append(insertParams.Level, database.LogLevelInfo)
insertParams.Stage = append(insertParams.Stage, c.preLogStage)
insertParams.Source = append(insertParams.Source, database.LogSourceProvisioner)
insertParams.Output = append(insertParams.Output, fmt.Sprintf("Output %d", i))
}
logs, err := db.InsertProvisionerJobLogs(ctx, insertParams)
require.NoError(t, err)
require.Len(t, logs, 10)
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
// Create pubsub subscription to listen for new log events.
pubsubCalled := make(chan int64, 1)
pubsubCancel, err := pubsub.Subscribe(provisionersdk.ProvisionerJobLogsNotifyChannel(templateImportJob.ID), func(ctx context.Context, message []byte) {
defer close(pubsubCalled)
var event provisionersdk.ProvisionerJobLogsNotifyMessage
err := json.Unmarshal(message, &event)
if !assert.NoError(t, err) {
return
}
assert.True(t, event.EndOfLogs)
pubsubCalled <- event.CreatedAfter
})
require.NoError(t, err)
defer pubsubCancel()
tickCh <- now
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
require.Contains(t, stats.TerminatedJobIDs, templateImportJob.ID)
after := <-pubsubCalled
// Get the jobs after the given time and check that they are what we
// expect.
logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: templateImportJob.ID,
CreatedAfter: after,
})
require.NoError(t, err)
require.Len(t, logs, len(unhanger.HungJobLogMessages))
for i, log := range logs {
assert.Equal(t, database.LogLevelError, log.Level)
assert.Equal(t, c.expectStage, log.Stage)
assert.Equal(t, database.LogSourceProvisionerDaemon, log.Source)
assert.Equal(t, unhanger.HungJobLogMessages[i], log.Output)
}
// Double check the full log count.
logs, err = db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
JobID: templateImportJob.ID,
CreatedAfter: 0,
})
require.NoError(t, err)
require.Len(t, logs, c.preLogCount+len(unhanger.HungJobLogMessages))
detector.Close()
detector.Wait()
})
}
}
func TestDetectorMaxJobsPerRun(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitLong)
db, pubsub = dbtestutil.NewDB(t)
log = slogtest.Make(t, nil)
tickCh = make(chan time.Time)
statsCh = make(chan unhanger.Stats)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{})
file = dbgen.File(t, db, database.File{})
)
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
now := time.Now()
for i := 0; i < unhanger.MaxJobsPerRun+1; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
CreatedAt: now.Add(-time.Hour),
UpdatedAt: now.Add(-time.Hour),
StartedAt: sql.NullTime{
Time: now.Add(-time.Hour),
Valid: true,
},
OrganizationID: org.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte("{}"),
})
}
detector := unhanger.New(ctx, db, pubsub, log, tickCh).WithStatsChannel(statsCh)
detector.Start()
tickCh <- now
// Make sure that only unhanger.MaxJobsPerRun jobs are terminated.
stats := <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, unhanger.MaxJobsPerRun)
// Run the detector again and make sure that only the remaining job is
// terminated.
tickCh <- now
stats = <-statsCh
require.NoError(t, stats.Error)
require.Len(t, stats.TerminatedJobIDs, 1)
detector.Close()
detector.Wait()
}