chore(coderd/provisionerdserver): convert dbmem tests to use postgres (#18278)

This commit is contained in:
Hugo Dutka
2025-06-09 10:05:29 +02:00
committed by GitHub
parent 7d8b994229
commit 910858b731
3 changed files with 270 additions and 128 deletions

View File

@ -18,7 +18,6 @@ import (
"go.uber.org/goleak" "go.uber.org/goleak"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/provisionerjobs" "github.com/coder/coder/v2/coderd/database/provisionerjobs"
@ -34,8 +33,7 @@ func TestMain(m *testing.M) {
// TestAcquirer_Store tests that a database.Store is accepted as a provisionerdserver.AcquirerStore // TestAcquirer_Store tests that a database.Store is accepted as a provisionerdserver.AcquirerStore
func TestAcquirer_Store(t *testing.T) { func TestAcquirer_Store(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, ps := dbtestutil.NewDB(t)
ps := pubsub.NewInMemory()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel() defer cancel()
logger := testutil.Logger(t) logger := testutil.Logger(t)

View File

@ -11,7 +11,7 @@ import (
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -21,14 +21,14 @@ func TestObtainOIDCAccessToken(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("NoToken", func(t *testing.T) { t.Run("NoToken", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
_, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil) _, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("InvalidConfig", func(t *testing.T) { t.Run("InvalidConfig", func(t *testing.T) {
// We still want OIDC to succeed even if exchanging the token fails. // We still want OIDC to succeed even if exchanging the token fails.
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
dbgen.UserLink(t, db, database.UserLink{ dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID, UserID: user.ID,
@ -40,7 +40,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
}) })
t.Run("MissingLink", func(t *testing.T) { t.Run("MissingLink", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{ user := dbgen.User(t, db, database.User{
LoginType: database.LoginTypeOIDC, LoginType: database.LoginTypeOIDC,
}) })
@ -50,7 +50,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
}) })
t.Run("Exchange", func(t *testing.T) { t.Run("Exchange", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
dbgen.UserLink(t, db, database.UserLink{ dbgen.UserLink(t, db, database.UserLink{
UserID: user.ID, UserID: user.ID,

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/url" "net/url"
"slices" "slices"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -31,7 +32,6 @@ import (
"github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/database/pubsub"
@ -164,6 +164,8 @@ func TestAcquireJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = tc.acquire(ctx, srv) _, err = tc.acquire(ctx, srv)
@ -175,7 +177,7 @@ func TestAcquireJob(t *testing.T) {
sdkproto.PrebuiltWorkspaceBuildStage_CLAIM, sdkproto.PrebuiltWorkspaceBuildStage_CLAIM,
} { } {
prebuiltWorkspaceBuildStage := prebuiltWorkspaceBuildStage prebuiltWorkspaceBuildStage := prebuiltWorkspaceBuildStage
t.Run(tc.name+"_WorkspaceBuildJob", func(t *testing.T) { t.Run(tc.name+"_WorkspaceBuildJob_Stage"+prebuiltWorkspaceBuildStage.String(), func(t *testing.T) {
t.Parallel() t.Parallel()
// Set the max session token lifetime so we can assert we // Set the max session token lifetime so we can assert we
// create an API key with an expiration within the bounds of the // create an API key with an expiration within the bounds of the
@ -240,10 +242,12 @@ func TestAcquireJob(t *testing.T) {
Name: "template", Name: "template",
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
CreatedBy: user.ID,
}) })
file := dbgen.File(t, db, database.File{CreatedBy: user.ID}) file := dbgen.File(t, db, database.File{CreatedBy: user.ID, Hash: "1"})
versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID}) versionFile := dbgen.File(t, db, database.File{CreatedBy: user.ID, Hash: "2"})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{ version := dbgen.TemplateVersion(t, db, database.TemplateVersion{
CreatedBy: user.ID,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
TemplateID: uuid.NullUUID{ TemplateID: uuid.NullUUID{
UUID: template.ID, UUID: template.ID,
@ -293,35 +297,33 @@ func TestAcquireJob(t *testing.T) {
Required: true, Required: true,
Sensitive: false, Sensitive: false,
}) })
workspace := database.WorkspaceTable{ workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
TemplateID: template.ID, TemplateID: template.ID,
OwnerID: user.ID, OwnerID: user.ID,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
} })
workspace = dbgen.Workspace(t, db, workspace) buildID := uuid.New()
build := database.WorkspaceBuild{ dbJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
WorkspaceID: workspace.ID,
BuildNumber: 1,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
}
build = dbgen.WorkspaceBuild(t, db, build)
input := provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
}
dbJob := database.ProvisionerJob{
ID: build.JobID,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
InitiatorID: user.ID, InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID, FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(input)), Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
} WorkspaceBuildID: buildID,
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob) })),
Tags: pd.Tags,
})
build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
ID: buildID,
WorkspaceID: workspace.ID,
BuildNumber: 1,
JobID: dbJob.ID,
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
})
var agent database.WorkspaceAgent var agent database.WorkspaceAgent
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM { if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
@ -332,25 +334,12 @@ func TestAcquireJob(t *testing.T) {
ResourceID: resource.ID, ResourceID: resource.ID,
AuthToken: uuid.New(), AuthToken: uuid.New(),
}) })
// At this point we have an unclaimed workspace and build, now we need to setup the claim buildID := uuid.New()
// build input := provisionerdserver.WorkspaceProvisionJob{
build = database.WorkspaceBuild{ WorkspaceBuildID: buildID,
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
InitiatorID: user.ID,
}
build = dbgen.WorkspaceBuild(t, db, build)
input = provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: build.ID,
PrebuiltWorkspaceBuildStage: prebuiltWorkspaceBuildStage, PrebuiltWorkspaceBuildStage: prebuiltWorkspaceBuildStage,
} }
dbJob = database.ProvisionerJob{ dbJob = database.ProvisionerJob{
ID: build.JobID,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
InitiatorID: user.ID, InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
@ -358,8 +347,22 @@ func TestAcquireJob(t *testing.T) {
FileID: file.ID, FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(input)), Input: must(json.Marshal(input)),
Tags: pd.Tags,
} }
dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob) dbJob = dbgen.ProvisionerJob(t, db, ps, dbJob)
// At this point we have an unclaimed workspace and build, now we need to setup the claim
// build.
build = database.WorkspaceBuild{
ID: buildID,
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: dbJob.ID,
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
InitiatorID: user.ID,
}
build = dbgen.WorkspaceBuild(t, db, build)
} }
startPublished := make(chan struct{}) startPublished := make(chan struct{})
@ -387,26 +390,19 @@ func TestAcquireJob(t *testing.T) {
// an import version job that we need to ignore. // an import version job that we need to ignore.
job, err = tc.acquire(ctx, srv) job, err = tc.acquire(ctx, srv)
require.NoError(t, err) require.NoError(t, err)
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok { if job, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
// In the case of a prebuild claim, there is a second build, which is the
// one that we're interested in.
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM &&
job.WorkspaceBuild.Metadata.PrebuiltWorkspaceBuildStage != prebuiltWorkspaceBuildStage {
continue
}
break break
} }
} }
<-startPublished <-startPublished
if prebuiltWorkspaceBuildStage == sdkproto.PrebuiltWorkspaceBuildStage_CLAIM {
for {
// In the case of a prebuild claim, there is a second build, which is the
// one that we're interested in.
job, err = tc.acquire(ctx, srv)
require.NoError(t, err)
if _, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_); ok {
break
}
}
<-startPublished
}
got, err := json.Marshal(job.Type) got, err := json.Marshal(job.Type)
require.NoError(t, err) require.NoError(t, err)
@ -480,26 +476,29 @@ func TestAcquireJob(t *testing.T) {
require.JSONEq(t, string(want), string(got)) require.JSONEq(t, string(want), string(got))
// Assert that we delete the session token whenever stopbuildID := uuid.New()
// a stop is issued. stopJob := dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
stopbuild := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ ID: stopbuildID,
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStop,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, db, ps, database.ProvisionerJob{
ID: stopbuild.ID,
InitiatorID: user.ID, InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID, FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{ Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: stopbuild.ID, WorkspaceBuildID: stopbuildID,
})), })),
Tags: pd.Tags,
})
// Assert that we delete the session token whenever
// a stop is issued.
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
ID: stopbuildID,
WorkspaceID: workspace.ID,
BuildNumber: 3,
JobID: stopJob.ID,
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStop,
Reason: database.BuildReasonInitiator,
}) })
stopPublished := make(chan struct{}) stopPublished := make(chan struct{})
@ -534,7 +533,7 @@ func TestAcquireJob(t *testing.T) {
} }
t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionDryRun", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps, _ := setup(t, false, nil) srv, db, ps, pd := setup(t, false, nil)
ctx := context.Background() ctx := context.Background()
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
@ -550,6 +549,7 @@ func TestAcquireJob(t *testing.T) {
TemplateVersionID: version.ID, TemplateVersionID: version.ID,
WorkspaceName: "testing", WorkspaceName: "testing",
})), })),
Tags: pd.Tags,
}) })
job, err := tc.acquire(ctx, srv) job, err := tc.acquire(ctx, srv)
@ -579,7 +579,7 @@ func TestAcquireJob(t *testing.T) {
}) })
t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionImport", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps, _ := setup(t, false, nil) srv, db, ps, pd := setup(t, false, nil)
ctx := context.Background() ctx := context.Background()
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
@ -590,6 +590,7 @@ func TestAcquireJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport, Type: database.ProvisionerJobTypeTemplateVersionImport,
Tags: pd.Tags,
}) })
job, err := tc.acquire(ctx, srv) job, err := tc.acquire(ctx, srv)
@ -611,7 +612,7 @@ func TestAcquireJob(t *testing.T) {
}) })
t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) { t.Run(tc.name+"_TemplateVersionImportWithUserVariable", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps, _ := setup(t, false, nil) srv, db, ps, pd := setup(t, false, nil)
user := dbgen.User(t, db, database.User{}) user := dbgen.User(t, db, database.User{})
version := dbgen.TemplateVersion(t, db, database.TemplateVersion{}) version := dbgen.TemplateVersion(t, db, database.TemplateVersion{})
@ -628,6 +629,7 @@ func TestAcquireJob(t *testing.T) {
{Name: "first", Value: "first_value"}, {Name: "first", Value: "first_value"},
}, },
})), })),
Tags: pd.Tags,
}) })
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
@ -674,12 +676,14 @@ func TestUpdateJob(t *testing.T) {
}) })
t.Run("NotRunning", func(t *testing.T) { t.Run("NotRunning", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _, _ := setup(t, false, nil) srv, db, _, pd := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
@ -690,12 +694,14 @@ func TestUpdateJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own! // This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Run("NotOwner", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _, _ := setup(t, false, nil) srv, db, _, pd := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -704,6 +710,11 @@ func TestUpdateJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{ _, err = srv.UpdateJob(ctx, &proto.UpdateJobRequest{
@ -712,12 +723,14 @@ func TestUpdateJob(t *testing.T) {
require.ErrorContains(t, err, "you don't own this job") require.ErrorContains(t, err, "you don't own this job")
}) })
setupJob := func(t *testing.T, db database.Store, srvID uuid.UUID) uuid.UUID { setupJob := func(t *testing.T, db database.Store, srvID uuid.UUID, tags database.StringMap) uuid.UUID {
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeTemplateVersionImport, Type: database.ProvisionerJobTypeTemplateVersionImport,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Input: json.RawMessage("{}"),
Tags: tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -726,6 +739,11 @@ func TestUpdateJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
return job.ID return job.ID
@ -734,7 +752,7 @@ func TestUpdateJob(t *testing.T) {
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _, pd := setup(t, false, &overrides{}) srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
_, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{ _, err := srv.UpdateJob(ctx, &proto.UpdateJobRequest{
JobId: job.String(), JobId: job.String(),
}) })
@ -744,7 +762,7 @@ func TestUpdateJob(t *testing.T) {
t.Run("Logs", func(t *testing.T) { t.Run("Logs", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, ps, pd := setup(t, false, &overrides{}) srv, db, ps, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
published := make(chan struct{}) published := make(chan struct{})
@ -769,7 +787,7 @@ func TestUpdateJob(t *testing.T) {
t.Run("Readme", func(t *testing.T) { t.Run("Readme", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _, pd := setup(t, false, &overrides{}) srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -795,7 +813,7 @@ func TestUpdateJob(t *testing.T) {
defer cancel() defer cancel()
srv, db, _, pd := setup(t, false, &overrides{}) srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -842,7 +860,7 @@ func TestUpdateJob(t *testing.T) {
defer cancel() defer cancel()
srv, db, _, pd := setup(t, false, &overrides{}) srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -888,7 +906,7 @@ func TestUpdateJob(t *testing.T) {
defer cancel() defer cancel()
srv, db, _, pd := setup(t, false, &overrides{}) srv, db, _, pd := setup(t, false, &overrides{})
job := setupJob(t, db, pd.ID) job := setupJob(t, db, pd.ID, pd.Tags)
versionID := uuid.New() versionID := uuid.New()
err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{ err := db.InsertTemplateVersion(ctx, database.InsertTemplateVersionParams{
ID: versionID, ID: versionID,
@ -933,12 +951,14 @@ func TestFailJob(t *testing.T) {
// This test prevents runners from updating jobs they don't own! // This test prevents runners from updating jobs they don't own!
t.Run("NotOwner", func(t *testing.T) { t.Run("NotOwner", func(t *testing.T) {
t.Parallel() t.Parallel()
srv, db, _, _ := setup(t, false, nil) srv, db, _, pd := setup(t, false, nil)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
ID: uuid.New(), ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport, Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -947,6 +967,11 @@ func TestFailJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = srv.FailJob(ctx, &proto.FailedJob{ _, err = srv.FailJob(ctx, &proto.FailedJob{
@ -962,6 +987,8 @@ func TestFailJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeTemplateVersionImport, Type: database.ProvisionerJobTypeTemplateVersionImport,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -970,6 +997,11 @@ func TestFailJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{ err = db.UpdateProvisionerJobWithCompleteByID(ctx, database.UpdateProvisionerJobWithCompleteByIDParams{
@ -996,11 +1028,19 @@ func TestFailJob(t *testing.T) {
auditor: auditor, auditor: auditor,
}) })
org := dbgen.Organization(t, db, database.Organization{}) org := dbgen.Organization(t, db, database.Organization{})
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ u := dbgen.User(t, db, database.User{})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: u.ID,
})
workspace, err := db.InsertWorkspace(ctx, database.InsertWorkspaceParams{
ID: uuid.New(), ID: uuid.New(),
AutomaticUpdates: database.AutomaticUpdatesNever, AutomaticUpdates: database.AutomaticUpdatesNever,
OrganizationID: org.ID, OrganizationID: org.ID,
TemplateID: tpl.ID,
OwnerID: u.ID,
}) })
require.NoError(t, err)
buildID := uuid.New() buildID := uuid.New()
input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{ input, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: buildID, WorkspaceBuildID: buildID,
@ -1014,6 +1054,7 @@ func TestFailJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{ err = db.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
@ -1032,6 +1073,11 @@ func TestFailJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1106,6 +1152,8 @@ func TestCompleteJob(t *testing.T) {
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1115,6 +1163,11 @@ func TestCompleteJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = srv.CompleteJob(ctx, &proto.CompletedJob{ _, err = srv.CompleteJob(ctx, &proto.CompletedJob{
@ -1145,6 +1198,7 @@ func TestCompleteJob(t *testing.T) {
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeTemplateVersionImport, Type: database.ProvisionerJobTypeTemplateVersionImport,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1153,7 +1207,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1192,6 +1248,8 @@ func TestCompleteJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1199,7 +1257,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1275,7 +1335,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1306,7 +1368,7 @@ func TestCompleteJob(t *testing.T) {
}}, }},
Timings: []*sdkproto.Timing{ Timings: []*sdkproto.Timing{
{ {
Stage: "test", Stage: "init",
Source: "test-source", Source: "test-source",
Resource: "test-resource", Resource: "test-resource",
Action: "test-action", Action: "test-action",
@ -1314,7 +1376,7 @@ func TestCompleteJob(t *testing.T) {
End: timestamppb.Now(), End: timestamppb.Now(),
}, },
{ {
Stage: "test2", Stage: "plan",
Source: "test-source2", Source: "test-source2",
Resource: "test-resource2", Resource: "test-resource2",
Action: "test-action2", Action: "test-action2",
@ -1383,7 +1445,7 @@ func TestCompleteJob(t *testing.T) {
timings, err := db.GetProvisionerJobTimingsByJobID(ctx, job.ID) timings, err := db.GetProvisionerJobTimingsByJobID(ctx, job.ID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, timings, 1, "Expected one timing entry to be created") require.Len(t, timings, 1, "Expected one timing entry to be created")
require.Equal(t, "test", string(timings[0].Stage), "Timing stage should match what was sent") require.Equal(t, "init", string(timings[0].Stage), "Timing stage should match what was sent")
}) })
}) })
@ -1405,6 +1467,7 @@ func TestCompleteJob(t *testing.T) {
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1413,7 +1476,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1459,6 +1524,7 @@ func TestCompleteJob(t *testing.T) {
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1468,6 +1534,11 @@ func TestCompleteJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
completeJob := func() { completeJob := func() {
@ -1517,6 +1588,7 @@ func TestCompleteJob(t *testing.T) {
Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`), Input: []byte(`{"template_version_id": "` + versionID.String() + `"}`),
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild, Type: database.ProvisionerJobTypeWorkspaceBuild,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1526,6 +1598,11 @@ func TestCompleteJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
completeJob := func() { completeJob := func() {
@ -1749,6 +1826,11 @@ func TestCompleteJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: c.now,
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1830,6 +1912,8 @@ func TestCompleteJob(t *testing.T) {
Provisioner: database.ProvisionerTypeEcho, Provisioner: database.ProvisionerTypeEcho,
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
StorageMethod: database.ProvisionerStorageMethodFile, StorageMethod: database.ProvisionerStorageMethodFile,
Input: json.RawMessage("{}"),
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
@ -1838,6 +1922,11 @@ func TestCompleteJob(t *testing.T) {
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
StartedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
ProvisionerTags: must(json.Marshal(job.Tags)),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1916,7 +2005,8 @@ func TestCompleteJob(t *testing.T) {
Transition: database.WorkspaceTransitionStart, Transition: database.WorkspaceTransitionStart,
}}, }},
provisionerJobParams: database.InsertProvisionerJobParams{ provisionerJobParams: database.InsertProvisionerJobParams{
Type: database.ProvisionerJobTypeTemplateVersionDryRun, Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: json.RawMessage("{}"),
}, },
}, },
{ {
@ -2060,6 +2150,10 @@ func TestCompleteJob(t *testing.T) {
if jobParams.StorageMethod == "" { if jobParams.StorageMethod == "" {
jobParams.StorageMethod = database.ProvisionerStorageMethodFile jobParams.StorageMethod = database.ProvisionerStorageMethodFile
} }
if jobParams.Tags == nil {
jobParams.Tags = pd.Tags
}
user := dbgen.User(t, db, database.User{})
job, err := db.InsertProvisionerJob(ctx, jobParams) job, err := db.InsertProvisionerJob(ctx, jobParams)
tpl := dbgen.Template(t, db, database.Template{ tpl := dbgen.Template(t, db, database.Template{
@ -2070,7 +2164,9 @@ func TestCompleteJob(t *testing.T) {
JobID: job.ID, JobID: job.ID,
}) })
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
TemplateID: tpl.ID, TemplateID: tpl.ID,
OrganizationID: pd.OrganizationID,
OwnerID: user.ID,
}) })
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
ID: workspaceBuildID, ID: workspaceBuildID,
@ -2085,7 +2181,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{jobParams.Provisioner}, Types: []database.ProvisionerType{jobParams.Provisioner},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -2176,22 +2274,34 @@ func TestCompleteJob(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ job, err := db.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
Input: input, ID: uuid.New(),
Provisioner: database.ProvisionerTypeEcho, CreatedAt: dbtime.Now(),
StorageMethod: database.ProvisionerStorageMethodFile, UpdatedAt: dbtime.Now(),
Type: database.ProvisionerJobTypeWorkspaceBuild, OrganizationID: pd.OrganizationID,
InitiatorID: uuid.New(),
Input: input,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Tags: pd.Tags,
}) })
require.NoError(t, err) require.NoError(t, err)
user := dbgen.User(t, db, database.User{})
tpl := dbgen.Template(t, db, database.Template{ tpl := dbgen.Template(t, db, database.Template{
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
CreatedBy: user.ID,
}) })
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true}, OrganizationID: pd.OrganizationID,
JobID: job.ID, TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
JobID: job.ID,
CreatedBy: user.ID,
}) })
workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ workspace := dbgen.Workspace(t, db, database.WorkspaceTable{
TemplateID: tpl.ID, TemplateID: tpl.ID,
OrganizationID: pd.OrganizationID,
OwnerID: user.ID,
}) })
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
ID: buildID, ID: buildID,
@ -2200,11 +2310,14 @@ func TestCompleteJob(t *testing.T) {
TemplateVersionID: tv.ID, TemplateVersionID: tv.ID,
}) })
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{ WorkerID: uuid.NullUUID{
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -2298,7 +2411,9 @@ func TestCompleteJob(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -2545,7 +2660,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
} }
t.Run("NoAgents", func(t *testing.T) { t.Run("NoAgents", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2558,7 +2674,9 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("InvalidAgentToken", func(t *testing.T) { t.Run("InvalidAgentToken", func(t *testing.T) {
t.Parallel() t.Parallel()
err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{ db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
err := insert(db, uuid.New(), &sdkproto.Resource{
Name: "something", Name: "something",
Type: "aws_instance", Type: "aws_instance",
Agents: []*sdkproto.Agent{{ Agents: []*sdkproto.Agent{{
@ -2572,7 +2690,9 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("DuplicateApps", func(t *testing.T) { t.Run("DuplicateApps", func(t *testing.T) {
t.Parallel() t.Parallel()
err := insert(dbmem.New(), uuid.New(), &sdkproto.Resource{ db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
err := insert(db, uuid.New(), &sdkproto.Resource{
Name: "something", Name: "something",
Type: "aws_instance", Type: "aws_instance",
Agents: []*sdkproto.Agent{{ Agents: []*sdkproto.Agent{{
@ -2585,7 +2705,10 @@ func TestInsertWorkspaceResource(t *testing.T) {
}}, }},
}) })
require.ErrorContains(t, err, `duplicate app slug, must be unique per template: "a"`) require.ErrorContains(t, err, `duplicate app slug, must be unique per template: "a"`)
err = insert(dbmem.New(), uuid.New(), &sdkproto.Resource{
db, _ = dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
err = insert(db, uuid.New(), &sdkproto.Resource{
Name: "something", Name: "something",
Type: "aws_instance", Type: "aws_instance",
Agents: []*sdkproto.Agent{{ Agents: []*sdkproto.Agent{{
@ -2604,7 +2727,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("AppSlugInvalid", func(t *testing.T) { t.Run("AppSlugInvalid", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2642,7 +2766,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("DuplicateAgentNames", func(t *testing.T) { t.Run("DuplicateAgentNames", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
// case-insensitive-unique // case-insensitive-unique
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
@ -2668,7 +2793,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("AgentNameInvalid", func(t *testing.T) { t.Run("AgentNameInvalid", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2697,7 +2823,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
}) })
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2754,7 +2881,7 @@ func TestInsertWorkspaceResource(t *testing.T) {
"else": "I laugh in the face of danger.", "else": "I laugh in the face of danger.",
}) })
require.NoError(t, err) require.NoError(t, err)
got, err := agent.EnvironmentVariables.RawMessage.MarshalJSON() got, err := json.Marshal(agent.EnvironmentVariables.RawMessage)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, want, got) require.Equal(t, want, got)
require.ElementsMatch(t, []database.DisplayApp{ require.ElementsMatch(t, []database.DisplayApp{
@ -2766,7 +2893,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
t.Run("AllDisplayApps", func(t *testing.T) { t.Run("AllDisplayApps", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2795,7 +2923,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
t.Run("DisableDefaultApps", func(t *testing.T) { t.Run("DisableDefaultApps", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2820,7 +2949,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
t.Run("ResourcesMonitoring", func(t *testing.T) { t.Run("ResourcesMonitoring", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2872,7 +3002,8 @@ func TestInsertWorkspaceResource(t *testing.T) {
t.Run("Devcontainers", func(t *testing.T) { t.Run("Devcontainers", func(t *testing.T) {
t.Parallel() t.Parallel()
db := dbmem.New() db, _ := dbtestutil.NewDB(t)
dbtestutil.DisableForeignKeysAndTriggers(t, db)
job := uuid.New() job := uuid.New()
err := insert(db, job, &sdkproto.Resource{ err := insert(db, job, &sdkproto.Resource{
Name: "something", Name: "something",
@ -2894,6 +3025,9 @@ func TestInsertWorkspaceResource(t *testing.T) {
require.Len(t, agents, 1) require.Len(t, agents, 1)
agent := agents[0] agent := agents[0]
devcontainers, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID) devcontainers, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID)
sort.Slice(devcontainers, func(i, j int) bool {
return devcontainers[i].Name > devcontainers[j].Name
})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, devcontainers, 2) require.Len(t, devcontainers, 2)
require.Equal(t, "foo", devcontainers[0].Name) require.Equal(t, "foo", devcontainers[0].Name)
@ -2989,6 +3123,8 @@ func TestNotifications(t *testing.T) {
WorkspaceBuildID: build.ID, WorkspaceBuildID: build.ID,
})), })),
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}) })
_, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err = db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
@ -2996,7 +3132,9 @@ func TestNotifications(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -3109,6 +3247,8 @@ func TestNotifications(t *testing.T) {
WorkspaceBuildID: build.ID, WorkspaceBuildID: build.ID,
})), })),
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}) })
_, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
@ -3116,7 +3256,9 @@ func TestNotifications(t *testing.T) {
UUID: pd.ID, UUID: pd.ID,
Valid: true, Valid: true,
}, },
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -3182,9 +3324,11 @@ func TestNotifications(t *testing.T) {
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
}) })
_, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{ _, err := db.AcquireProvisionerJob(ctx, database.AcquireProvisionerJobParams{
OrganizationID: pd.OrganizationID, OrganizationID: pd.OrganizationID,
WorkerID: uuid.NullUUID{UUID: pd.ID, Valid: true}, WorkerID: uuid.NullUUID{UUID: pd.ID, Valid: true},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho}, Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(job.Tags)),
StartedAt: sql.NullTime{Time: job.CreatedAt, Valid: true},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -3230,8 +3374,8 @@ type overrides struct {
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) { func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub, database.ProvisionerDaemon) {
t.Helper() t.Helper()
logger := testutil.Logger(t) logger := testutil.Logger(t)
db := dbmem.New() db, ps := dbtestutil.NewDB(t)
ps := pubsub.NewInMemory() dbtestutil.DisableForeignKeysAndTriggers(t, db)
defOrg, err := db.GetDefaultOrganization(context.Background()) defOrg, err := db.GetDefaultOrganization(context.Background())
require.NoError(t, err, "default org not found") require.NoError(t, err, "default org not found")