feat: integrate Acquirer for provisioner jobs (#9717)

* chore: add Acquirer to provisionerdserver pkg

Signed-off-by: Spike Curtis <spike@coder.com>

* code review improvements & fixes

Signed-off-by: Spike Curtis <spike@coder.com>

* feat: integrate Acquirer for provisioner jobs

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix imports, whitespace

Signed-off-by: Spike Curtis <spike@coder.com>

* provisionerdserver always closes; remove poll interval from playwright

Signed-off-by: Spike Curtis <spike@coder.com>

* post jobs outside transactions

Signed-off-by: Spike Curtis <spike@coder.com>

* graceful shutdown in test

Signed-off-by: Spike Curtis <spike@coder.com>

* Mark AcquireJob deprecated

Signed-off-by: Spike Curtis <spike@coder.com>

* Graceful shutdown on all provisionerd tests

Signed-off-by: Spike Curtis <spike@coder.com>

* Deprecate, not remove CLI flags

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis
2023-09-19 10:25:57 +04:00
committed by GitHub
parent 6cf531bfef
commit 375c70d141
41 changed files with 1497 additions and 1096 deletions

View File

@ -134,7 +134,7 @@ func Test_ActivityBumpWorkspace(t *testing.T) {
TemplateID: template.ID,
Ttl: sql.NullInt64{Valid: true, Int64: int64(tt.workspaceTTL)},
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
CompletedAt: tt.jobCompletedAt,
})
@ -225,7 +225,7 @@ func Test_ActivityBumpWorkspace(t *testing.T) {
func insertPrevWorkspaceBuild(t *testing.T, db database.Store, orgID, tvID, workspaceID uuid.UUID, transition database.WorkspaceTransition, buildNumber int32) {
t.Helper()
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: orgID,
})
_ = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{

View File

@ -16,6 +16,8 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/schedule"
"github.com/coder/coder/v2/coderd/schedule/cron"
"github.com/coder/coder/v2/coderd/wsbuilder"
@ -26,6 +28,7 @@ import (
type Executor struct {
ctx context.Context
db database.Store
ps pubsub.Pubsub
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
log slog.Logger
tick <-chan time.Time
@ -40,11 +43,12 @@ type Stats struct {
}
// New returns a new wsactions executor.
func NewExecutor(ctx context.Context, db database.Store, tss *atomic.Pointer[schedule.TemplateScheduleStore], log slog.Logger, tick <-chan time.Time) *Executor {
func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, tss *atomic.Pointer[schedule.TemplateScheduleStore], log slog.Logger, tick <-chan time.Time) *Executor {
le := &Executor{
//nolint:gocritic // Autostart has a limited set of permissions.
ctx: dbauthz.AsAutostart(ctx),
db: db,
ps: ps,
templateScheduleStore: tss,
tick: tick,
log: log.Named("autobuild"),
@ -129,6 +133,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
log := e.log.With(slog.F("workspace_id", wsID))
eg.Go(func() error {
var job *database.ProvisionerJob
err := e.db.InTx(func(tx database.Store) error {
// Re-check eligibility since the first check was outside the
// transaction and the workspace settings may have changed.
@ -168,7 +173,8 @@ func (e *Executor) runOnce(t time.Time) Stats {
SetLastWorkspaceBuildJobInTx(&latestJob).
Reason(reason)
if _, _, err := builder.Build(e.ctx, tx, nil); err != nil {
_, job, err = builder.Build(e.ctx, tx, nil)
if err != nil {
log.Error(e.ctx, "unable to transition workspace",
slog.F("transition", nextTransition),
slog.Error(err),
@ -230,6 +236,17 @@ func (e *Executor) runOnce(t time.Time) Stats {
if err != nil {
log.Error(e.ctx, "workspace scheduling failed", slog.Error(err))
}
if job != nil && err == nil {
// Note that we can't refactor such that posting the job happens inside wsbuilder because it's called
// with an outer transaction like this, and we need to make sure the outer transaction commits before
// posting the job. If we post before the transaction commits, provisionerd might try to acquire the
// job, fail, and then sit idle instead of picking up the job.
err = provisionerjobs.PostJob(e.ps, *job)
if err != nil {
// Client probably doesn't care about this error, so just log it.
log.Error(e.ctx, "failed to post provisioner job to pubsub", slog.Error(err))
}
}
return nil
})
}

View File

@ -14,6 +14,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/cryptorand"
@ -26,11 +27,11 @@ func TestBatchStats(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
store, _ := dbtestutil.NewDB(t)
store, ps := dbtestutil.NewDB(t)
// Set up some test dependencies.
deps1 := setupDeps(t, store)
deps2 := setupDeps(t, store)
deps1 := setupDeps(t, store, ps)
deps2 := setupDeps(t, store, ps)
tick := make(chan time.Time)
flushed := make(chan int, 1)
@ -168,7 +169,7 @@ type deps struct {
// It creates an organization, user, template, workspace, and agent
// along with all the other miscellaneous plumbing required to link
// them together.
func setupDeps(t *testing.T, store database.Store) deps {
func setupDeps(t *testing.T, store database.Store, ps pubsub.Pubsub) deps {
t.Helper()
org := dbgen.Organization(t, store, database.Organization{})
@ -194,7 +195,7 @@ func setupDeps(t *testing.T, store database.Store) deps {
OrganizationID: org.ID,
LastUsedAt: time.Now().Add(-time.Hour),
})
pj := dbgen.ProvisionerJob(t, store, database.ProvisionerJob{
pj := dbgen.ProvisionerJob(t, store, ps, database.ProvisionerJob{
InitiatorID: user.ID,
OrganizationID: org.ID,
})

View File

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"io"
@ -366,6 +365,11 @@ func New(options *Options) *API {
UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore,
Experiments: experiments,
healthCheckGroup: &singleflight.Group[string, *healthcheck.Report]{},
Acquirer: provisionerdserver.NewAcquirer(
ctx,
options.Logger.Named("acquirer"),
options.Database,
options.Pubsub),
}
if options.UpdateCheckOptions != nil {
api.updateChecker = updatecheck.New(
@ -1016,6 +1020,8 @@ type API struct {
healthCheckCache atomic.Pointer[healthcheck.Report]
statsBatcher *batchstats.Batcher
Acquirer *provisionerdserver.Acquirer
}
// Close waits for all WebSocket connections to drain before returning.
@ -1067,7 +1073,7 @@ func compressHandler(h http.Handler) http.Handler {
// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd.
// Useful when starting coderd and provisionerd in the same process.
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context) (client proto.DRPCProvisionerDaemonClient, err error) {
tracer := api.TracerProvider.Tracer(tracing.TracerName)
clientSession, serverSession := provisionersdk.MemTransportPipe()
defer func() {
@ -1077,11 +1083,8 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
}
}()
tags, err := json.Marshal(database.StringMap{
tags := provisionerdserver.Tags{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
})
if err != nil {
return nil, xerrors.Errorf("marshal tags: %w", err)
}
mux := drpcmux.New()
@ -1098,6 +1101,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
tags,
api.Database,
api.Pubsub,
api.Acquirer,
api.Telemetry,
tracer,
&api.QuotaCommitter,
@ -1105,7 +1109,6 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
api.TemplateScheduleStore,
api.UserQuietHoursScheduleStore,
api.DeploymentValues,
debounce,
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
GitAuthConfigs: api.GitAuthConfigs,

View File

@ -266,6 +266,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
lifecycleExecutor := autobuild.NewExecutor(
ctx,
options.Database,
options.Pubsub,
&templateScheduleStore,
slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug),
options.AutobuildTicker,
@ -453,6 +454,30 @@ func NewWithAPI(t testing.TB, options *Options) (*codersdk.Client, io.Closer, *c
return client, provisionerCloser, coderAPI
}
// provisionerdCloser wraps a provisioner daemon as an io.Closer that can be called multiple times
type provisionerdCloser struct {
mu sync.Mutex
closed bool
d *provisionerd.Server
}
func (c *provisionerdCloser) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
shutdownErr := c.d.Shutdown(ctx)
closeErr := c.d.Close()
if shutdownErr != nil {
return shutdownErr
}
return closeErr
}
// NewProvisionerDaemon launches a provisionerd instance configured to work
// well with coderd testing. It registers the "echo" provisioner for
// quick testing.
@ -482,17 +507,17 @@ func NewProvisionerDaemon(t testing.TB, coderAPI *coderd.API) io.Closer {
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, 0)
daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx)
}, &provisionerd.Options{
Logger: coderAPI.Logger.Named("provisionerd").Leveled(slog.LevelDebug),
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})
closer := &provisionerdCloser{d: daemon}
t.Cleanup(func() {
_ = closer.Close()
})
@ -518,7 +543,7 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
daemon := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
Organization: org,
Provisioners: []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho},
@ -526,13 +551,13 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
})
}, &provisionerd.Options{
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),
JobPollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Connector: provisionerd.LocalProvisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(echoClient),
},
})
closer := &provisionerdCloser{d: daemon}
t.Cleanup(func() {
_ = closer.Close()
})

View File

@ -344,14 +344,14 @@ func (s *MethodTestSuite) TestGroup() {
func (s *MethodTestSuite) TestProvsionerJob() {
s.Run("Build/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
w := dbgen.Workspace(s.T(), db, database.Workspace{})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID})
check.Args(j.ID).Asserts(w, rbac.ActionRead).Returns(j)
}))
s.Run("TemplateVersion/GetProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionImport,
})
tpl := dbgen.Template(s.T(), db, database.Template{})
@ -366,7 +366,7 @@ func (s *MethodTestSuite) TestProvsionerJob() {
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: must(json.Marshal(struct {
TemplateVersionID uuid.UUID `json:"template_version_id"`
@ -377,7 +377,7 @@ func (s *MethodTestSuite) TestProvsionerJob() {
s.Run("Build/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) {
tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: true})
w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID})
@ -386,14 +386,14 @@ func (s *MethodTestSuite) TestProvsionerJob() {
s.Run("BuildFalseCancel/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) {
tpl := dbgen.Template(s.T(), db, database.Template{AllowUserCancelWorkspaceJobs: false})
w := dbgen.Workspace(s.T(), db, database.Workspace{TemplateID: tpl.ID})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID})
check.Args(database.UpdateProvisionerJobWithCancelByIDParams{ID: j.ID}).Asserts(w, rbac.ActionUpdate).Returns()
}))
s.Run("TemplateVersion/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionImport,
})
tpl := dbgen.Template(s.T(), db, database.Template{})
@ -405,7 +405,7 @@ func (s *MethodTestSuite) TestProvsionerJob() {
Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns()
}))
s.Run("TemplateVersionNoTemplate/UpdateProvisionerJobWithCancelByID", s.Subtest(func(db database.Store, check *expects) {
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionImport,
})
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
@ -420,7 +420,7 @@ func (s *MethodTestSuite) TestProvsionerJob() {
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: must(json.Marshal(struct {
TemplateVersionID uuid.UUID `json:"template_version_id"`
@ -430,13 +430,13 @@ func (s *MethodTestSuite) TestProvsionerJob() {
Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionUpdate}).Returns()
}))
s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) {
a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
a := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
b := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
check.Args([]uuid.UUID{a.ID, b.ID}).Asserts().Returns(slice.New(a, b))
}))
s.Run("GetProvisionerLogsAfterID", s.Subtest(func(db database.Store, check *expects) {
w := dbgen.Workspace(s.T(), db, database.Workspace{})
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
_ = dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{JobID: j.ID, WorkspaceID: w.ID})
@ -1151,20 +1151,20 @@ func (s *MethodTestSuite) TestWorkspace() {
s.Run("GetWorkspaceResourceByID", s.Subtest(func(db database.Store, check *expects) {
ws := dbgen.Workspace(s.T(), db, database.Workspace{})
build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()})
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
_ = dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
res := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID})
check.Args(res.ID).Asserts(ws, rbac.ActionRead).Returns(res)
}))
s.Run("Build/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) {
ws := dbgen.Workspace(s.T(), db, database.Workspace{})
build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()})
job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
check.Args(job.ID).Asserts(ws, rbac.ActionRead).Returns([]database.WorkspaceResource{})
}))
s.Run("Template/GetWorkspaceResourcesByJobID", s.Subtest(func(db database.Store, check *expects) {
tpl := dbgen.Template(s.T(), db, database.Template{})
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()})
job := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport})
job := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport})
check.Args(job.ID).Asserts(v.RBACObject(tpl), []rbac.Action{rbac.ActionRead, rbac.ActionRead}).Returns([]database.WorkspaceResource{})
}))
s.Run("InsertWorkspace", s.Subtest(func(db database.Store, check *expects) {
@ -1411,7 +1411,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("GetProvisionerJobsCreatedAfter", s.Subtest(func(db database.Store, check *expects) {
// TODO: add provisioner job resource type
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)})
_ = dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{CreatedAt: time.Now().Add(-time.Hour)})
check.Args(time.Now()).Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ )
}))
s.Run("GetTemplateVersionsByIDs", s.Subtest(func(db database.Store, check *expects) {
@ -1450,11 +1450,11 @@ func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("GetWorkspaceResourcesByJobIDs", s.Subtest(func(db database.Store, check *expects) {
tpl := dbgen.Template(s.T(), db, database.Template{})
v := dbgen.TemplateVersion(s.T(), db, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, JobID: uuid.New()})
tJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport})
tJob := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: v.JobID, Type: database.ProvisionerJobTypeTemplateVersionImport})
ws := dbgen.Workspace(s.T(), db, database.Workspace{})
build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()})
wJob := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
wJob := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
check.Args([]uuid.UUID{tJob.ID, wJob.ID}).
Asserts(rbac.ResourceSystem, rbac.ActionRead).
Returns([]database.WorkspaceResource{})
@ -1462,7 +1462,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("GetWorkspaceResourceMetadataByResourceIDs", s.Subtest(func(db database.Store, check *expects) {
ws := dbgen.Workspace(s.T(), db, database.Workspace{})
build := dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{WorkspaceID: ws.ID, JobID: uuid.New()})
_ = dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
_ = dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{ID: build.JobID, Type: database.ProvisionerJobTypeWorkspaceBuild})
a := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID})
b := dbgen.WorkspaceResource(s.T(), db, database.WorkspaceResource{JobID: build.JobID})
check.Args([]uuid.UUID{a.ID, b.ID}).
@ -1479,8 +1479,8 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("GetProvisionerJobsByIDs", s.Subtest(func(db database.Store, check *expects) {
// TODO: add a ProvisionerJob resource type
a := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
b := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
a := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
b := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
check.Args([]uuid.UUID{a.ID, b.ID}).
Asserts( /*rbac.ResourceSystem, rbac.ActionRead*/ ).
Returns(slice.New(a, b))
@ -1514,7 +1514,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("AcquireProvisionerJob", s.Subtest(func(db database.Store, check *expects) {
// TODO: we need to create a ProvisionerJob resource
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
@ -1522,14 +1522,14 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {
// TODO: we need to create a ProvisionerJob resource
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobWithCompleteByIDParams{
ID: j.ID,
}).Asserts( /*rbac.ResourceSystem, rbac.ActionUpdate*/ )
}))
s.Run("UpdateProvisionerJobByID", s.Subtest(func(db database.Store, check *expects) {
// TODO: we need to create a ProvisionerJob resource
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
check.Args(database.UpdateProvisionerJobByIDParams{
ID: j.ID,
UpdatedAt: time.Now(),
@ -1546,7 +1546,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("InsertProvisionerJobLogs", s.Subtest(func(db database.Store, check *expects) {
// TODO: we need to create a ProvisionerJob resource
j := dbgen.ProvisionerJob(s.T(), db, database.ProvisionerJob{})
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{})
check.Args(database.InsertProvisionerJobLogsParams{
JobID: j.ID,
}).Asserts( /*rbac.ResourceSystem, rbac.ActionCreate*/ )

View File

@ -19,6 +19,8 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/cryptorand"
)
@ -315,8 +317,9 @@ func GroupMember(t testing.TB, db database.Store, orig database.GroupMember) dat
return member
}
// ProvisionerJob is a bit more involved to get the values such as "completedAt", "startedAt", "cancelledAt" set.
func ProvisionerJob(t testing.TB, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob {
// ProvisionerJob is a bit more involved to get the values such as "completedAt", "startedAt", "cancelledAt" set. ps
// can be set to nil if you are SURE that you don't require a provisionerdaemon to acquire the job in your test.
func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig database.ProvisionerJob) database.ProvisionerJob {
id := takeFirst(orig.ID, uuid.New())
// Always set some tags to prevent Acquire from grabbing jobs it should not.
if !orig.StartedAt.Time.IsZero() {
@ -341,7 +344,10 @@ func ProvisionerJob(t testing.TB, db database.Store, orig database.ProvisionerJo
Tags: orig.Tags,
})
require.NoError(t, err, "insert job")
if ps != nil {
err = provisionerjobs.PostJob(ps, job)
require.NoError(t, err, "post job to pubsub")
}
if !orig.StartedAt.Time.IsZero() {
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
StartedAt: orig.StartedAt,

View File

@ -86,7 +86,7 @@ func TestGenerator(t *testing.T) {
t.Run("Job", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
exp := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
exp := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{})
require.Equal(t, exp, must(db.GetProvisionerJobByID(context.Background(), exp.ID)))
})

View File

@ -0,0 +1,29 @@
package provisionerjobs
import (
"encoding/json"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
const EventJobPosted = "provisioner_job_posted"
type JobPosting struct {
ProvisionerType database.ProvisionerType `json:"type"`
Tags map[string]string `json:"tags"`
}
func PostJob(ps pubsub.Pubsub, job database.ProvisionerJob) error {
msg, err := json.Marshal(JobPosting{
ProvisionerType: job.Provisioner,
Tags: job.Tags,
})
if err != nil {
return xerrors.Errorf("marshal job posting: %w", err)
}
err = ps.Publish(EventJobPosted, msg)
return err
}

View File

@ -103,7 +103,7 @@ func TestInsertWorkspaceAgentLogs(t *testing.T) {
require.NoError(t, err)
db := database.New(sqlDB)
org := dbgen.Organization(t, db, database.Organization{})
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
@ -335,7 +335,7 @@ func TestQueuePosition(t *testing.T) {
jobs := []database.ProvisionerJob{}
jobIDs := []uuid.UUID{}
for i := 0; i < jobCount; i++ {
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
Tags: database.StringMap{},
})

View File

@ -83,7 +83,7 @@ func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Ha
OrganizationID: org.ID,
TemplateID: template.ID,
})
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
})
resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{

View File

@ -34,7 +34,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
ID: build.JobID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Provisioner: database.ProvisionerTypeEcho,

View File

@ -363,7 +363,7 @@ func setupWorkspaceWithAgents(t testing.TB, cfg setupConfig) (database.Store, *h
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
ID: build.JobID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Provisioner: database.ProvisionerTypeEcho,

View File

@ -21,7 +21,7 @@ func TestWorkspaceResourceParam(t *testing.T) {
setup := func(t *testing.T, db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) {
r := httptest.NewRequest("GET", "/", nil)
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
Type: jobType,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,

View File

@ -15,14 +15,13 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
const (
EventJobPosted = "provisioner_job_posted"
dbMaxBackoff = 10 * time.Second
dbMaxBackoff = 10 * time.Second
// backPollDuration is the period for the backup polling described in Acquirer comment
backupPollDuration = 30 * time.Second
)
@ -106,8 +105,6 @@ func (a *Acquirer) AcquireJob(
}
// buffer of 1 so that cancel doesn't deadlock while writing to the channel
clearance := make(chan struct{}, 1)
//nolint:gocritic // Provisionerd has specific authz rules.
principal := dbauthz.AsProvisionerd(ctx)
for {
a.want(pt, tags, clearance)
select {
@ -122,7 +119,7 @@ func (a *Acquirer) AcquireJob(
return database.ProvisionerJob{}, err
case <-clearance:
logger.Debug(ctx, "got clearance to call database")
job, err := a.store.AcquireProvisionerJob(principal, database.AcquireProvisionerJobParams{
job, err := a.store.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
@ -298,7 +295,7 @@ func (a *Acquirer) subscribe() {
bkoff := backoff.WithContext(eb, a.ctx)
var cancel context.CancelFunc
err := backoff.Retry(func() error {
cancelFn, err := a.ps.SubscribeWithErr(EventJobPosted, a.jobPosted)
cancelFn, err := a.ps.SubscribeWithErr(provisionerjobs.EventJobPosted, a.jobPosted)
if err != nil {
a.logger.Warn(a.ctx, "failed to subscribe to job postings", slog.Error(err))
return err
@ -335,7 +332,7 @@ func (a *Acquirer) jobPosted(ctx context.Context, message []byte, err error) {
a.logger.Warn(a.ctx, "unhandled pubsub error", slog.Error(err))
return
}
posting := JobPosting{}
posting := provisionerjobs.JobPosting{}
err = json.Unmarshal(message, &posting)
if err != nil {
a.logger.Error(a.ctx, "unable to parse job posting",
@ -457,7 +454,7 @@ type domain struct {
acquirees map[chan<- struct{}]*acquiree
}
func (d domain) contains(p JobPosting) bool {
func (d domain) contains(p provisionerjobs.JobPosting) bool {
if !slices.Contains(d.pt, p.ProvisionerType) {
return false
}
@ -485,8 +482,3 @@ func (d domain) poll(dur time.Duration) {
}
}
}
type JobPosting struct {
ProvisionerType database.ProvisionerType `json:"type"`
Tags map[string]string `json:"tags"`
}

View File

@ -18,6 +18,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/testutil"
@ -316,12 +317,12 @@ func TestAcquirer_UnblockOnCancel(t *testing.T) {
func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
t.Helper()
msg, err := json.Marshal(provisionerdserver.JobPosting{
msg, err := json.Marshal(provisionerjobs.JobPosting{
ProvisionerType: pt,
Tags: tags,
})
require.NoError(t, err)
err = ps.Publish(provisionerdserver.EventJobPosted, msg)
err = ps.Publish(provisionerjobs.EventJobPosted, msg)
require.NoError(t, err)
}

View File

@ -11,7 +11,6 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@ -44,16 +43,18 @@ import (
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)
var (
lastAcquire time.Time
lastAcquireMutex sync.RWMutex
)
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
// canceling and returning an empty job.
const DefaultAcquireJobLongPollDur = time.Second * 5
type Options struct {
OIDCConfig httpmw.OAuth2Config
GitAuthConfigs []*gitauth.Config
// TimeNowFn is only used in tests
TimeNowFn func() time.Time
// AcquireJobLongPollDur is used in tests
AcquireJobLongPollDur time.Duration
}
type server struct {
@ -62,9 +63,10 @@ type server struct {
Logger slog.Logger
Provisioners []database.ProvisionerType
GitAuthConfigs []*gitauth.Config
Tags json.RawMessage
Tags Tags
Database database.Store
Pubsub pubsub.Pubsub
Acquirer *Acquirer
Telemetry telemetry.Reporter
Tracer trace.Tracer
QuotaCommitter *atomic.Pointer[proto.QuotaCommitter]
@ -73,10 +75,11 @@ type server struct {
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
DeploymentValues *codersdk.DeploymentValues
AcquireJobDebounce time.Duration
OIDCConfig httpmw.OAuth2Config
OIDCConfig httpmw.OAuth2Config
TimeNowFn func() time.Time
acquireJobLongPollDur time.Duration
}
// We use the null byte (0x00) in generating a canonical map key for tags, so
@ -108,9 +111,10 @@ func NewServer(
id uuid.UUID,
logger slog.Logger,
provisioners []database.ProvisionerType,
tags json.RawMessage,
tags Tags,
db database.Store,
ps pubsub.Pubsub,
acquirer *Acquirer,
tel telemetry.Reporter,
tracer trace.Tracer,
quotaCommitter *atomic.Pointer[proto.QuotaCommitter],
@ -118,7 +122,6 @@ func NewServer(
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore],
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore],
deploymentValues *codersdk.DeploymentValues,
acquireJobDebounce time.Duration,
options Options,
) (proto.DRPCProvisionerDaemonServer, error) {
// Panic early if pointers are nil
@ -137,6 +140,18 @@ func NewServer(
if deploymentValues == nil {
return nil, xerrors.New("deploymentValues is nil")
}
if acquirer == nil {
return nil, xerrors.New("acquirer is nil")
}
if tags == nil {
return nil, xerrors.Errorf("tags is nil")
}
if err := tags.Valid(); err != nil {
return nil, xerrors.Errorf("invalid tags: %w", err)
}
if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
}
return &server{
AccessURL: accessURL,
ID: id,
@ -146,6 +161,7 @@ func NewServer(
Tags: tags,
Database: db,
Pubsub: ps,
Acquirer: acquirer,
Telemetry: tel,
Tracer: tracer,
QuotaCommitter: quotaCommitter,
@ -153,9 +169,9 @@ func NewServer(
TemplateScheduleStore: templateScheduleStore,
UserQuietHoursScheduleStore: userQuietHoursScheduleStore,
DeploymentValues: deploymentValues,
AcquireJobDebounce: acquireJobDebounce,
OIDCConfig: options.OIDCConfig,
TimeNowFn: options.TimeNowFn,
acquireJobLongPollDur: options.AcquireJobLongPollDur,
}, nil
}
@ -169,50 +185,119 @@ func (s *server) timeNow() time.Time {
}
// AcquireJob queries the database to lock a job.
//
// Deprecated: This method is only available for back-level provisioner daemons.
func (s *server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
//nolint:gocritic // Provisionerd has specific authz rules.
ctx = dbauthz.AsProvisionerd(ctx)
// This prevents loads of provisioner daemons from consistently
// querying the database when no jobs are available.
//
// The debounce only occurs when no job is returned, so if loads of
// jobs are added at once, they will start after at most this duration.
lastAcquireMutex.RLock()
if !lastAcquire.IsZero() && time.Since(lastAcquire) < s.AcquireJobDebounce {
s.Logger.Debug(ctx, "debounce acquire job", slog.F("debounce", s.AcquireJobDebounce), slog.F("last_acquire", lastAcquire))
lastAcquireMutex.RUnlock()
return &proto.AcquiredJob{}, nil
}
lastAcquireMutex.RUnlock()
// This marks the job as locked in the database.
job, err := s.Database.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
WorkerID: uuid.NullUUID{
UUID: s.ID,
Valid: true,
},
Types: s.Provisioners,
Tags: s.Tags,
})
if errors.Is(err, sql.ErrNoRows) {
// The provisioner daemon assumes no jobs are available if
// an empty struct is returned.
lastAcquireMutex.Lock()
lastAcquire = dbtime.Now()
lastAcquireMutex.Unlock()
// Since AcquireJob blocks until a job is available, we set a long (5s by default) timeout. This allows back-level
// provisioner daemons to gracefully shut down within a few seconds, but keeps them from rapidly polling the
// database.
acqCtx, acqCancel := context.WithTimeout(ctx, s.acquireJobLongPollDur)
defer acqCancel()
job, err := s.Acquirer.AcquireJob(acqCtx, s.ID, s.Provisioners, s.Tags)
if xerrors.Is(err, context.DeadlineExceeded) {
s.Logger.Debug(ctx, "successful cancel")
return &proto.AcquiredJob{}, nil
}
if err != nil {
return nil, xerrors.Errorf("acquire job: %w", err)
}
s.Logger.Debug(ctx, "locked job from database", slog.F("job_id", job.ID))
return s.acquireProtoJob(ctx, job)
}
type jobAndErr struct {
job database.ProvisionerJob
err error
}
// AcquireJobWithCancel queries the database to lock a job.
func (s *server) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) (retErr error) {
//nolint:gocritic // Provisionerd has specific authz rules.
streamCtx := dbauthz.AsProvisionerd(stream.Context())
defer func() {
closeErr := stream.Close()
s.Logger.Debug(streamCtx, "closed stream", slog.Error(closeErr))
if retErr == nil {
retErr = closeErr
}
}()
acqCtx, acqCancel := context.WithCancel(streamCtx)
defer acqCancel()
recvCh := make(chan error, 1)
go func() {
_, err := stream.Recv() // cancel is the only message
recvCh <- err
}()
jec := make(chan jobAndErr, 1)
go func() {
job, err := s.Acquirer.AcquireJob(acqCtx, s.ID, s.Provisioners, s.Tags)
jec <- jobAndErr{job: job, err: err}
}()
var recvErr error
var je jobAndErr
select {
case recvErr = <-recvCh:
acqCancel()
je = <-jec
case je = <-jec:
}
if xerrors.Is(je.err, context.Canceled) {
s.Logger.Debug(streamCtx, "successful cancel")
err := stream.Send(&proto.AcquiredJob{})
if err != nil {
// often this is just because the other side hangs up and doesn't wait for the cancel, so log at INFO
s.Logger.Info(streamCtx, "failed to send empty job", slog.Error(err))
return err
}
return nil
}
if je.err != nil {
return xerrors.Errorf("acquire job: %w", je.err)
}
logger := s.Logger.With(slog.F("job_id", je.job.ID))
logger.Debug(streamCtx, "locked job from database")
if recvErr != nil {
logger.Error(streamCtx, "recv error and failed to cancel acquire job", slog.Error(recvErr))
// Well, this is awkward. We hit an error receiving from the stream, but didn't cancel before we locked a job
// in the database. We need to mark this job as failed so the end user can retry if they want to.
err := s.Database.UpdateProvisionerJobWithCompleteByID(
context.Background(),
database.UpdateProvisionerJobWithCompleteByIDParams{
ID: je.job.ID,
CompletedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
Error: sql.NullString{
String: "connection to provisioner daemon broken",
Valid: true,
},
})
if err != nil {
logger.Error(streamCtx, "error updating failed job", slog.Error(err))
}
return recvErr
}
pj, err := s.acquireProtoJob(streamCtx, je.job)
if err != nil {
return err
}
err = stream.Send(pj)
if err != nil {
s.Logger.Error(streamCtx, "failed to send job", slog.Error(err))
return err
}
return nil
}
func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJob) (*proto.AcquiredJob, error) {
// Marks the acquired job as failed with the error message provided.
failJob := func(errorMessage string) error {
err = s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
err := s.Database.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
CompletedAt: sql.NullTime{
Time: dbtime.Now(),

View File

@ -4,12 +4,19 @@ import (
"context"
"database/sql"
"encoding/json"
"io"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/xerrors"
"storj.io/drpc"
"cdr.dev/slog"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -57,400 +64,411 @@ func testUserQuietHoursScheduleStore() *atomic.Pointer[schedule.UserQuietHoursSc
return ptr
}
func TestAcquireJob_LongPoll(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, &overrides{acquireJobLongPollDuration: time.Microsecond})
job, err := srv.AcquireJob(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job)
}
func TestAcquireJobWithCancel_Cancel(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
fs := newFakeStream(ctx)
errCh := make(chan error)
go func() {
errCh <- srv.AcquireJobWithCancel(fs)
}()
fs.cancel()
select {
case <-ctx.Done():
t.Fatal("timed out waiting for AcquireJobWithCancel")
case err := <-errCh:
require.NoError(t, err)
}
job, err := fs.waitForJob()
require.NoError(t, err)
require.NotNil(t, job)
require.Equal(t, "", job.JobId)
}
func TestAcquireJob(t *testing.T) {
t.Parallel()
t.Run("Debounce", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
ps := pubsub.NewInMemory()
srv, err := provisionerdserver.NewServer(
&url.URL{},
uuid.New(),
slogtest.Make(t, nil),
[]database.ProvisionerType{database.ProvisionerTypeEcho},
nil,
db,
ps,
telemetry.NewNoop(),
trace.NewNoopTracerProvider().Tracer("noop"),
&atomic.Pointer[proto.QuotaCommitter]{},
mockAuditor(),
testTemplateScheduleStore(),
testUserQuietHoursScheduleStore(),
&codersdk.DeploymentValues{},
time.Hour,
provisionerdserver.Options{},
)
require.NoError(t, err)
job, err := srv.AcquireJob(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job)
_, err = db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: uuid.New(),
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
require.NoError(t, err)
job, err = srv.AcquireJob(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job)
})
t.Run("NoJobs", func(t *testing.T) {
t.Parallel()
srv, _, _ := setup(t, false, nil)
job, err := srv.AcquireJob(context.Background(), nil)
require.NoError(t, err)
require.Equal(t, &proto.AcquiredJob{}, job)
})
t.Run("InitiatorNotFound", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: uuid.New(),
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
require.NoError(t, err)
_, err = srv.AcquireJob(context.Background(), nil)
require.ErrorContains(t, err, "sql: no rows in result set")
})
t.Run("WorkspaceBuildJob", func(t *testing.T) {
t.Parallel()
// Set the max session token lifetime so we can assert we
// create an API key with an expiration within the bounds of the
// deployment config.
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
gitAuthProvider := "github"
srv, db, ps := setup(t, false, &overrides{
deploymentValues: dv,
gitAuthConfigs: []*gitauth.Config{{
ID: gitAuthProvider,
OAuth2Config: &testutil.OAuth2Config{},
}},
})
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
link := dbgen.UserLink(t, db, database.UserLink{
LoginType: database.LoginTypeOIDC,
UserID: user.ID,
OAuthExpiry: dbtime.Now().Add(time.Hour),
OAuthAccessToken: "access-token",
})
dbgen.GitAuthLink(t, db, database.GitAuthLink{
ProviderID: gitAuthProvider,
UserID: user.ID,
})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
})
err := db.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
JobID: version.JobID,
GitAuthProviders: []string{gitAuthProvider},
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
// Import version job
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
ID: version.JobID,
InitiatorID: user.ID,
FileID: versionFile.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: version.ID,
UserVariableValues: []codersdk.VariableValue{
{Name: "second", Value: "bah"},
},
})),
})
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
TemplateVersionID: version.ID,
Name: "first",
Value: "first_value",
DefaultValue: "default_value",
Sensitive: true,
})
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
TemplateVersionID: version.ID,
Name: "second",
Value: "second_value",
DefaultValue: "default_value",
Required: true,
Sensitive: false,
})
workspace := dbgen.Workspace(t, db, database.Workspace{
TemplateID: template.ID,
OwnerID: user.ID,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 1,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
ID: build.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
})
startPublished := make(chan struct{})
var closed bool
closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
if !closed {
close(startPublished)
closed = true
// These test acquiring a single job without canceling, and tests both AcquireJob (deprecated) and
// AcquireJobWithCancel as the way to get the job.
cases := []struct {
name string
acquire func(context.Context, proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error)
}{
{name: "Deprecated", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
return srv.AcquireJob(ctx, nil)
}},
{name: "WithCancel", acquire: func(ctx context.Context, srv proto.DRPCProvisionerDaemonServer) (*proto.AcquiredJob, error) {
fs := newFakeStream(ctx)
err := srv.AcquireJobWithCancel(fs)
if err != nil {
return nil, err
}
return fs.waitForJob()
}},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name+"_InitiatorNotFound", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: uuid.New(),
InitiatorID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
require.NoError(t, err)
_, err = tc.acquire(ctx, srv)
require.ErrorContains(t, err, "sql: no rows in result set")
})
require.NoError(t, err)
defer closeStartSubscribe()
t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) {
t.Parallel()
// Set the max session token lifetime so we can assert we
// create an API key with an expiration within the bounds of the
// deployment config.
dv := &codersdk.DeploymentValues{MaxTokenLifetime: clibase.Duration(time.Hour)}
gitAuthProvider := "github"
srv, db, ps := setup(t, false, &overrides{
deploymentValues: dv,
gitAuthConfigs: []*gitauth.Config{{
ID: gitAuthProvider,
OAuth2Config: &testutil.OAuth2Config{},
}},
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
var job *proto.AcquiredJob
user := dbgen.User(t, db, database.User{})
link := dbgen.UserLink(t, db, database.UserLink{
LoginType: database.LoginTypeOIDC,
UserID: user.ID,
OAuthExpiry: dbtime.Now().Add(time.Hour),
OAuthAccessToken: "access-token",
})
dbgen.GitAuthLink(t, db, database.GitAuthLink{
ProviderID: gitAuthProvider,
UserID: user.ID,
})
template := dbgen.Template(t, db, database.Template{
Name: "template",
Provisioner: database.ProvisionerTypeEcho,
})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{
UUID: template.ID,
Valid: true,
},
JobID: uuid.New(),
})
err := db.UpdateTemplateVersionGitAuthProvidersByJobID(ctx, database.UpdateTemplateVersionGitAuthProvidersByJobIDParams{
JobID: version.JobID,
GitAuthProviders: []string{gitAuthProvider},
UpdatedAt: dbtime.Now(),
})
require.NoError(t, err)
// Import version job
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: version.JobID,
InitiatorID: user.ID,
FileID: versionFile.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: version.ID,
UserVariableValues: []codersdk.VariableValue{
{Name: "second", Value: "bah"},
},
})),
})
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
TemplateVersionID: version.ID,
Name: "first",
Value: "first_value",
DefaultValue: "default_value",
Sensitive: true,
})
_ = dbgen.TemplateVersionVariable(t, db, database.TemplateVersionVariable{
TemplateVersionID: version.ID,
Name: "second",
Value: "second_value",
DefaultValue: "default_value",
Required: true,
Sensitive: false,
})
workspace := dbgen.Workspace(t, db, database.Workspace{
TemplateID: template.ID,
OwnerID: user.ID,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 1,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: build.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
})),
})
startPublished := make(chan struct{})
var closed bool
closeStartSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
if !closed {
close(startPublished)
closed = true
}
})
require.NoError(t, err)
defer closeStartSubscribe()
var job *proto.AcquiredJob
for {
// Grab jobs until we find the workspace build job. There is also
// an import version job that we need to ignore.
job, err = tc.acquire(ctx, srv)
require.NoError(t, err)
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
break
}
}
<-startPublished
got, err := json.Marshal(job.Type)
require.NoError(t, err)
// Validate that a session token is generated during the job.
sessionToken := job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
require.NotEmpty(t, sessionToken)
toks := strings.Split(sessionToken, "-")
require.Len(t, toks, 2, "invalid api key")
key, err := db.GetAPIKeyByID(ctx, toks[0])
require.NoError(t, err)
require.Equal(t, int64(dv.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds)
require.WithinDuration(t, time.Now().Add(dv.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute)
want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: build.ID.String(),
WorkspaceName: workspace.Name,
VariableValues: []*sdkproto.VariableValue{
{
Name: "first",
Value: "first_value",
Sensitive: true,
},
{
Name: "second",
Value: "second_value",
},
},
GitAuthProviders: []*sdkproto.GitAuthProvider{{
Id: gitAuthProvider,
AccessToken: "access_token",
}},
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
WorkspaceName: workspace.Name,
WorkspaceOwner: user.Username,
WorkspaceOwnerEmail: user.Email,
WorkspaceOwnerOidcAccessToken: link.OAuthAccessToken,
WorkspaceId: workspace.ID.String(),
WorkspaceOwnerId: user.ID.String(),
TemplateId: template.ID.String(),
TemplateName: template.Name,
TemplateVersion: version.Name,
WorkspaceOwnerSessionToken: sessionToken,
},
},
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
// Assert that we delete the session token whenever
// a stop is issued.
stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStop,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: stopbuild.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: stopbuild.ID,
})),
})
stopPublished := make(chan struct{})
closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
close(stopPublished)
})
require.NoError(t, err)
defer closeStopSubscribe()
for {
// Grab jobs until we find the workspace build job. There is also
// an import version job that we need to ignore.
job, err = srv.AcquireJob(ctx, nil)
job, err = tc.acquire(ctx, srv)
require.NoError(t, err)
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
break
}
}
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
require.True(t, ok, "acquired job not a workspace build?")
<-startPublished
<-stopPublished
got, err := json.Marshal(job.Type)
require.NoError(t, err)
// Validate that a session token is deleted during a stop job.
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
require.Empty(t, sessionToken)
_, err = db.GetAPIKeyByID(ctx, key.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
})
// Validate that a session token is generated during the job.
sessionToken := job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
require.NotEmpty(t, sessionToken)
toks := strings.Split(sessionToken, "-")
require.Len(t, toks, 2, "invalid api key")
key, err := db.GetAPIKeyByID(ctx, toks[0])
require.NoError(t, err)
require.Equal(t, int64(dv.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds)
require.WithinDuration(t, time.Now().Add(dv.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute)
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
ctx := context.Background()
want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
WorkspaceBuildId: build.ID.String(),
WorkspaceName: workspace.Name,
VariableValues: []*sdkproto.VariableValue{
{
Name: "first",
Value: "first_value",
Sensitive: true,
},
{
Name: "second",
Value: "second_value",
user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
TemplateVersionID: version.ID,
WorkspaceName: "testing",
})),
})
job, err := tc.acquire(ctx, srv)
require.NoError(t, err)
got, err := json.Marshal(job.Type)
require.NoError(t, err)
want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{
TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
WorkspaceName: "testing",
},
},
GitAuthProviders: []*sdkproto.GitAuthProvider{{
Id: gitAuthProvider,
AccessToken: "access_token",
}},
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
WorkspaceName: workspace.Name,
WorkspaceOwner: user.Username,
WorkspaceOwnerEmail: user.Email,
WorkspaceOwnerOidcAccessToken: link.OAuthAccessToken,
WorkspaceId: workspace.ID.String(),
WorkspaceOwnerId: user.ID.String(),
TemplateId: template.ID.String(),
TemplateName: template.Name,
TemplateVersion: version.Name,
WorkspaceOwnerSessionToken: sessionToken,
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
})
job, err := tc.acquire(ctx, srv)
require.NoError(t, err)
got, err := json.Marshal(job.Type)
require.NoError(t, err)
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
},
},
},
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
require.NoError(t, err)
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
t.Parallel()
srv, db, ps := setup(t, false, nil)
require.JSONEq(t, string(want), string(got))
user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: version.ID,
UserVariableValues: []codersdk.VariableValue{
{Name: "first", Value: "first_value"},
},
})),
})
// Assert that we delete the session token whenever
// a stop is issued.
stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStop,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
ID: stopbuild.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: stopbuild.ID,
})),
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
stopPublished := make(chan struct{})
closeStopSubscribe, err := ps.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
close(stopPublished)
})
require.NoError(t, err)
defer closeStopSubscribe()
job, err := tc.acquire(ctx, srv)
require.NoError(t, err)
// Grab jobs until we find the workspace build job. There is also
// an import version job that we need to ignore.
job, err = srv.AcquireJob(ctx, nil)
require.NoError(t, err)
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
require.True(t, ok, "acquired job not a workspace build?")
got, err := json.Marshal(job.Type)
require.NoError(t, err)
<-stopPublished
// Validate that a session token is deleted during a stop job.
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
require.Empty(t, sessionToken)
_, err = db.GetAPIKeyByID(ctx, key.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
})
t.Run("TemplateVersionDryRun", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: must(json.Marshal(provisionerdserver.TemplateVersionDryRunJob{
TemplateVersionID: version.ID,
WorkspaceName: "testing",
})),
})
job, err := srv.AcquireJob(ctx, nil)
require.NoError(t, err)
got, err := json.Marshal(job.Type)
require.NoError(t, err)
want, err := json.Marshal(&proto.AcquiredJob_TemplateDryRun_{
TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
WorkspaceName: "testing",
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
UserVariableValues: []*sdkproto.VariableValue{
{Name: "first", Sensitive: true, Value: "first_value"},
},
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
},
},
},
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
t.Run("TemplateVersionImport", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
ctx := context.Background()
user := dbgen.User(t, db, database.User{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
})
job, err := srv.AcquireJob(ctx, nil)
require.NoError(t, err)
got, err := json.Marshal(job.Type)
require.NoError(t, err)
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
},
},
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
t.Run("TemplateVersionImportWithUserVariable", func(t *testing.T) {
t.Parallel()
srv, db, _ := setup(t, false, nil)
user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
file := dbgen.File(t, db, database.File{CreatedBy: user.ID})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
FileID: file.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: must(json.Marshal(provisionerdserver.TemplateVersionImportJob{
TemplateVersionID: version.ID,
UserVariableValues: []codersdk.VariableValue{
{Name: "first", Value: "first_value"},
},
})),
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
job, err := srv.AcquireJob(ctx, nil)
require.NoError(t, err)
got, err := json.Marshal(job.Type)
require.NoError(t, err)
want, err := json.Marshal(&proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
UserVariableValues: []*sdkproto.VariableValue{
{Name: "first", Sensitive: true, Value: "first_value"},
},
Metadata: &sdkproto.Metadata{
CoderUrl: (&url.URL{}).String(),
},
},
})
require.NoError(t, err)
require.JSONEq(t, string(want), string(got))
})
}
}
func TestUpdateJob(t *testing.T) {
@ -1142,7 +1160,7 @@ func TestCompleteJob(t *testing.T) {
Transition: c.transition,
Reason: database.BuildReasonInitiator,
})
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
@ -1390,7 +1408,7 @@ func TestCompleteJob(t *testing.T) {
Transition: c.transition,
Reason: database.BuildReasonInitiator,
})
job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
job := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
@ -1662,10 +1680,14 @@ type overrides struct {
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
timeNowFn func() time.Time
acquireJobLongPollDuration time.Duration
}
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
db := dbfake.New()
ps := pubsub.NewInMemory()
deploymentValues := &codersdk.DeploymentValues{}
@ -1674,6 +1696,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
tss := testTemplateScheduleStore()
uqhss := testUserQuietHoursScheduleStore()
var timeNowFn func() time.Time
pollDur := time.Duration(0)
if ov != nil {
if ov.deploymentValues != nil {
deploymentValues = ov.deploymentValues
@ -1705,6 +1728,7 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
if ov.timeNowFn != nil {
timeNowFn = ov.timeNowFn
}
pollDur = ov.acquireJobLongPollDuration
}
srv, err := provisionerdserver.NewServer(
@ -1712,9 +1736,10 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
srvID,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
[]database.ProvisionerType{database.ProvisionerTypeEcho},
nil,
provisionerdserver.Tags{},
db,
ps,
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
telemetry.NewNoop(),
trace.NewNoopTracerProvider().Tracer("noop"),
&atomic.Pointer[proto.QuotaCommitter]{},
@ -1722,12 +1747,11 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
tss,
uqhss,
deploymentValues,
// Negative values cause the debounce to never kick in.
-time.Minute,
provisionerdserver.Options{
GitAuthConfigs: gitAuthConfigs,
TimeNowFn: timeNowFn,
OIDCConfig: &oauth2.Config{},
GitAuthConfigs: gitAuthConfigs,
TimeNowFn: timeNowFn,
OIDCConfig: &oauth2.Config{},
AcquireJobLongPollDur: pollDur,
},
)
require.NoError(t, err)
@ -1740,3 +1764,95 @@ func must[T any](value T, err error) T {
}
return value
}
var (
errUnimplemented = xerrors.New("unimplemented")
errClosed = xerrors.New("closed")
)
type fakeStream struct {
ctx context.Context
c *sync.Cond
closed bool
canceled bool
sendCalled bool
job *proto.AcquiredJob
}
func newFakeStream(ctx context.Context) *fakeStream {
return &fakeStream{
ctx: ctx,
c: sync.NewCond(&sync.Mutex{}),
}
}
func (s *fakeStream) Send(j *proto.AcquiredJob) error {
s.c.L.Lock()
defer s.c.L.Unlock()
s.sendCalled = true
s.job = j
s.c.Broadcast()
return nil
}
func (s *fakeStream) Recv() (*proto.CancelAcquire, error) {
s.c.L.Lock()
defer s.c.L.Unlock()
for !(s.canceled || s.closed) {
s.c.Wait()
}
if s.canceled {
return &proto.CancelAcquire{}, nil
}
return nil, io.EOF
}
// Context returns the context associated with the stream. It is canceled
// when the Stream is closed and no more messages will ever be sent or
// received on it.
func (s *fakeStream) Context() context.Context {
return s.ctx
}
// MsgSend sends the Message to the remote.
func (*fakeStream) MsgSend(drpc.Message, drpc.Encoding) error {
return errUnimplemented
}
// MsgRecv receives a Message from the remote.
func (*fakeStream) MsgRecv(drpc.Message, drpc.Encoding) error {
return errUnimplemented
}
// CloseSend signals to the remote that we will no longer send any messages.
func (*fakeStream) CloseSend() error {
return errUnimplemented
}
// Close closes the stream.
func (s *fakeStream) Close() error {
s.c.L.Lock()
defer s.c.L.Unlock()
s.closed = true
s.c.Broadcast()
return nil
}
func (s *fakeStream) waitForJob() (*proto.AcquiredJob, error) {
s.c.L.Lock()
defer s.c.L.Unlock()
for !(s.sendCalled || s.closed) {
s.c.Wait()
}
if s.sendCalled {
return s.job, nil
}
return nil, errClosed
}
func (s *fakeStream) cancel() {
s.c.L.Lock()
defer s.c.L.Unlock()
s.canceled = true
s.c.Broadcast()
}

View File

@ -40,7 +40,7 @@ func TestTelemetry(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitMedium)
_, _ = dbgen.APIKey(t, db, database.APIKey{})
_ = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
_ = dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
Provisioner: database.ProvisionerTypeTerraform,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,

View File

@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/gitauth"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
@ -502,6 +503,11 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques
})
return
}
err = provisionerjobs.PostJob(api.Pubsub, provisionerJob)
if err != nil {
// Client probably doesn't care about this error, so just log it.
api.Logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err))
}
httpapi.Write(ctx, rw, http.StatusCreated, convertProvisionerJob(database.GetProvisionerJobsByIDsWithQueuePositionRow{
ProvisionerJob: provisionerJob,
@ -1289,6 +1295,11 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return
}
aReq.New = templateVersion
err = provisionerjobs.PostJob(api.Pubsub, provisionerJob)
if err != nil {
// Client probably doesn't care about this error, so just log it.
api.Logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err))
}
httpapi.Write(ctx, rw, http.StatusCreated, convertTemplateVersion(templateVersion, convertProvisionerJob(database.GetProvisionerJobsByIDsWithQueuePositionRow{
ProvisionerJob: provisionerJob,

View File

@ -67,7 +67,7 @@ func TestDetectorNoHungJobs(t *testing.T) {
user := dbgen.User(t, db, database.User{})
file := dbgen.File(t, db, database.File{})
for i := 0; i < 5; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: now.Add(-time.Minute * 5),
UpdatedAt: now.Add(-time.Minute * time.Duration(i)),
StartedAt: sql.NullTime{
@ -135,7 +135,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
// Previous build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
@ -163,7 +163,7 @@ func TestDetectorHungWorkspaceBuild(t *testing.T) {
})
// Current build.
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -256,7 +256,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
})
// Previous build.
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
previousWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: twentyMinAgo,
UpdatedAt: twentyMinAgo,
StartedAt: sql.NullTime{
@ -285,7 +285,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideState(t *testing.T) {
// Current build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -379,7 +379,7 @@ func TestDetectorHungWorkspaceBuildNoOverrideStateIfNoExistingBuild(t *testing.T
// First build.
expectedWorkspaceBuildState = []byte(`{"dean":"cool","colin":"also cool"}`)
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
currentWorkspaceBuildJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -454,7 +454,7 @@ func TestDetectorHungOtherJobTypes(t *testing.T) {
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
templateImportJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -471,7 +471,7 @@ func TestDetectorHungOtherJobTypes(t *testing.T) {
})
// Template dry-run job.
templateDryRunJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
templateDryRunJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -545,7 +545,7 @@ func TestDetectorHungCanceledJob(t *testing.T) {
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
templateImportJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
CanceledAt: sql.NullTime{
Time: tenMinAgo,
@ -642,7 +642,7 @@ func TestDetectorPushesLogs(t *testing.T) {
file = dbgen.File(t, db, database.File{})
// Template import job.
templateImportJob = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
templateImportJob = dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: tenMinAgo,
UpdatedAt: sixMinAgo,
StartedAt: sql.NullTime{
@ -752,7 +752,7 @@ func TestDetectorMaxJobsPerRun(t *testing.T) {
// Create unhanger.MaxJobsPerRun + 1 hung jobs.
now := time.Now()
for i := 0; i < unhanger.MaxJobsPerRun+1; i++ {
dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
CreatedAt: now.Add(-time.Hour),
UpdatedAt: now.Add(-time.Hour),
StartedAt: sql.NullTime{

View File

@ -20,6 +20,7 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
@ -373,6 +374,11 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
})
return
}
err = provisionerjobs.PostJob(api.Pubsub, *provisionerJob)
if err != nil {
// Client probably doesn't care about this error, so just log it.
api.Logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err))
}
users, err := api.Database.GetUsersByIDs(ctx, []uuid.UUID{
workspace.OwnerID,

View File

@ -19,6 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac"
@ -485,7 +486,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
}
workspaceBuild, provisionerJob, err = builder.Build(
ctx, db, func(action rbac.Action, object rbac.Objecter) bool {
ctx,
db,
func(action rbac.Action, object rbac.Objecter) bool {
return api.Authorize(r, action, object)
})
return err
@ -505,6 +508,11 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
})
return
}
err = provisionerjobs.PostJob(api.Pubsub, *provisionerJob)
if err != nil {
// Client probably doesn't care about this error, so just log it.
api.Logger.Error(ctx, "failed to post provisioner job to pubsub", slog.Error(err))
}
aReq.New = workspace
initiator, err := api.Database.GetUserByID(ctx, workspaceBuild.InitiatorID)

View File

@ -789,7 +789,7 @@ func TestWorkspaceFilterAllStatus(t *testing.T) {
file := dbgen.File(t, db, database.File{
CreatedBy: owner.UserID,
})
versionJob := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{
versionJob := dbgen.ProvisionerJob(t, db, pubsub, database.ProvisionerJob{
OrganizationID: owner.OrganizationID,
InitiatorID: owner.UserID,
WorkerID: uuid.NullUUID{},
@ -825,7 +825,7 @@ func TestWorkspaceFilterAllStatus(t *testing.T) {
job.Tags = database.StringMap{
jobID.String(): "true",
}
job = dbgen.ProvisionerJob(t, db, job)
job = dbgen.ProvisionerJob(t, db, pubsub, job)
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,