feat: support prebuilt workspaces in non-default organizations (#18010)

closes https://github.com/coder/internal/issues/527
This commit is contained in:
Sas Swart
2025-06-04 14:20:29 +02:00
committed by GitHub
parent 4d0fe20ca6
commit 5f7e5d7097
8 changed files with 598 additions and 324 deletions

View File

@ -412,6 +412,21 @@ var (
policy.ActionCreate, policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionCreate, policy.ActionDelete, policy.ActionRead, policy.ActionUpdate,
policy.ActionWorkspaceStart, policy.ActionWorkspaceStop, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop,
}, },
// Should be able to add the prebuilds system user as a member to any organization that needs prebuilds.
rbac.ResourceOrganizationMember.Type: {
policy.ActionCreate,
},
// Needs to be able to assign roles to the system user in order to make it a member of an organization.
rbac.ResourceAssignOrgRole.Type: {
policy.ActionAssign,
},
// Needs to be able to read users to determine which organizations the prebuild system user is a member of.
rbac.ResourceUser.Type: {
policy.ActionRead,
},
rbac.ResourceOrganization.Type: {
policy.ActionRead,
},
}), }),
}, },
}), }),

View File

@ -33,6 +33,8 @@ const (
orgUserAdmin string = "organization-user-admin" orgUserAdmin string = "organization-user-admin"
orgTemplateAdmin string = "organization-template-admin" orgTemplateAdmin string = "organization-template-admin"
orgWorkspaceCreationBan string = "organization-workspace-creation-ban" orgWorkspaceCreationBan string = "organization-workspace-creation-ban"
prebuildsOrchestrator string = "prebuilds-orchestrator"
) )
func init() { func init() {
@ -599,6 +601,9 @@ var assignRoles = map[string]map[string]bool{
orgUserAdmin: { orgUserAdmin: {
orgMember: true, orgMember: true,
}, },
prebuildsOrchestrator: {
orgMember: true,
},
} }
// ExpandableRoles is any type that can be expanded into a []Role. This is implemented // ExpandableRoles is any type that can be expanded into a []Role. This is implemented

View File

@ -19,13 +19,16 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/enterprise/coderd/prebuilds" "github.com/coder/coder/v2/enterprise/coderd/prebuilds"
"github.com/coder/coder/v2/provisioner/echo" "github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -105,7 +108,6 @@ func TestClaimPrebuild(t *testing.T) {
expectPrebuildClaimed: true, expectPrebuildClaimed: true,
markPrebuildsClaimable: true, markPrebuildsClaimable: true,
}, },
"no claimable prebuilt workspaces error is returned": { "no claimable prebuilt workspaces error is returned": {
expectPrebuildClaimed: false, expectPrebuildClaimed: false,
markPrebuildsClaimable: true, markPrebuildsClaimable: true,
@ -124,227 +126,248 @@ func TestClaimPrebuild(t *testing.T) {
} }
for name, tc := range cases { for name, tc := range cases {
tc := tc // Ensure that prebuilt workspaces can be claimed in non-default organizations:
for _, useDefaultOrg := range []bool{true, false} {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
t.Run(name, func(t *testing.T) { // Setup.
t.Parallel() ctx := testutil.Context(t, testutil.WaitSuperLong)
db, pubsub := dbtestutil.NewDB(t)
// Setup. spy := newStoreSpy(db, tc.claimingErr)
ctx := testutil.Context(t, testutil.WaitSuperLong) expectedPrebuildsCount := desiredInstances * presetCount
db, pubsub := dbtestutil.NewDB(t)
spy := newStoreSpy(db, tc.claimingErr) logger := testutil.Logger(t)
expectedPrebuildsCount := desiredInstances * presetCount client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: spy,
Pubsub: pubsub,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
},
logger := testutil.Logger(t) EntitlementsUpdateInterval: time.Second,
client, _, api, owner := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ })
Options: &coderdtest.Options{
IncludeProvisionerDaemon: true,
Database: spy,
Pubsub: pubsub,
},
EntitlementsUpdateInterval: time.Second, orgID := owner.OrganizationID
}) if !useDefaultOrg {
secondOrg := dbgen.Organization(t, db, database.Organization{})
reconciler := prebuilds.NewStoreReconciler(spy, pubsub, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) orgID = secondOrg.ID
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy)
api.AGPL.PrebuildsClaimer.Store(&claimer)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, templateWithAgentAndPresetsWithPrebuilds(desiredInstances))
_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
presets, err := client.TemplateVersionPresets(ctx, version.ID)
require.NoError(t, err)
require.Len(t, presets, presetCount)
userClient, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleMember())
// Given: the reconciliation state is snapshot.
state, err := reconciler.SnapshotState(ctx, spy)
require.NoError(t, err)
require.Len(t, state.Presets, presetCount)
// When: a reconciliation is setup for each preset.
for _, preset := range presets {
ps, err := state.FilterByPreset(preset.ID)
require.NoError(t, err)
require.NotNil(t, ps)
actions, err := reconciler.CalculateActions(ctx, *ps)
require.NoError(t, err)
require.NotNil(t, actions)
require.NoError(t, reconciler.ReconcilePreset(ctx, *ps))
}
// Given: a set of running, eligible prebuilds eventually starts up.
runningPrebuilds := make(map[uuid.UUID]database.GetRunningPrebuiltWorkspacesRow, desiredInstances*presetCount)
require.Eventually(t, func() bool {
rows, err := spy.GetRunningPrebuiltWorkspaces(ctx)
if err != nil {
return false
} }
for _, row := range rows { provisionerCloser := coderdenttest.NewExternalProvisionerDaemon(t, client, orgID, map[string]string{
runningPrebuilds[row.CurrentPresetID.UUID] = row provisionersdk.TagScope: provisionersdk.ScopeOrganization,
})
defer provisionerCloser.Close()
if !tc.markPrebuildsClaimable { reconciler := prebuilds.NewStoreReconciler(spy, pubsub, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
continue var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy)
} api.AGPL.PrebuildsClaimer.Store(&claimer)
agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, row.ID) version := coderdtest.CreateTemplateVersion(t, client, orgID, templateWithAgentAndPresetsWithPrebuilds(desiredInstances))
if err != nil { _ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
return false coderdtest.CreateTemplate(t, client, orgID, version.ID)
} presets, err := client.TemplateVersionPresets(ctx, version.ID)
// Workspaces are eligible once its agent is marked "ready".
for _, agent := range agents {
err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: sql.NullTime{Time: time.Now().Add(time.Hour), Valid: true},
ReadyAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
})
if err != nil {
return false
}
}
}
t.Logf("found %d running prebuilds so far, want %d", len(runningPrebuilds), expectedPrebuildsCount)
return len(runningPrebuilds) == expectedPrebuildsCount
}, testutil.WaitSuperLong, testutil.IntervalSlow)
// When: a user creates a new workspace with a preset for which prebuilds are configured.
workspaceName := strings.ReplaceAll(testutil.GetRandomName(t), "_", "-")
params := database.ClaimPrebuiltWorkspaceParams{
NewUserID: user.ID,
NewName: workspaceName,
PresetID: presets[0].ID,
}
userWorkspace, err := userClient.CreateUserWorkspace(ctx, user.Username, codersdk.CreateWorkspaceRequest{
TemplateVersionID: version.ID,
Name: workspaceName,
TemplateVersionPresetID: presets[0].ID,
})
isNoPrebuiltWorkspaces := errors.Is(tc.claimingErr, agplprebuilds.ErrNoClaimablePrebuiltWorkspaces)
isUnsupported := errors.Is(tc.claimingErr, agplprebuilds.ErrAGPLDoesNotSupportPrebuiltWorkspaces)
switch {
case tc.claimingErr != nil && (isNoPrebuiltWorkspaces || isUnsupported):
require.NoError(t, err) require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID) require.Len(t, presets, presetCount)
// Then: the number of running prebuilds hasn't changed because claiming prebuild is failed and we fallback to creating new workspace. userClient, user := coderdtest.CreateAnotherUser(t, client, orgID, rbac.RoleMember())
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return
case tc.claimingErr != nil && errors.Is(tc.claimingErr, unexpectedClaimingError):
// Then: unexpected error happened and was propagated all the way to the caller
require.Error(t, err)
require.ErrorContains(t, err, unexpectedClaimingError.Error())
// Then: the number of running prebuilds hasn't changed because claiming prebuild is failed.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return
default:
// tc.claimingErr is nil scenario
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)
}
// at this point we know that tc.claimingErr is nil
// Then: a prebuild should have been claimed.
require.EqualValues(t, spy.claims.Load(), 1)
require.EqualValues(t, *spy.claimParams.Load(), params)
if !tc.expectPrebuildClaimed {
require.Nil(t, spy.claimedWorkspace.Load())
return
}
require.NotNil(t, spy.claimedWorkspace.Load())
claimed := *spy.claimedWorkspace.Load()
require.NotEqual(t, claimed.ID, uuid.Nil)
// Then: the claimed prebuild must now be owned by the requester.
workspace, err := spy.GetWorkspaceByID(ctx, claimed.ID)
require.NoError(t, err)
require.Equal(t, user.ID, workspace.OwnerID)
// Then: the number of running prebuilds has changed since one was claimed.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount-1, len(currentPrebuilds))
// Then: the claimed prebuild is now missing from the running prebuilds set.
found := slices.ContainsFunc(currentPrebuilds, func(prebuild database.GetRunningPrebuiltWorkspacesRow) bool {
return prebuild.ID == claimed.ID
})
require.False(t, found, "claimed prebuild should not still be considered a running prebuild")
// Then: reconciling at this point will provision a new prebuild to replace the claimed one.
{
// Given: the reconciliation state is snapshot. // Given: the reconciliation state is snapshot.
state, err = reconciler.SnapshotState(ctx, spy) state, err := reconciler.SnapshotState(ctx, spy)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, state.Presets, presetCount)
// When: a reconciliation is setup for each preset. // When: a reconciliation is setup for each preset.
for _, preset := range presets { for _, preset := range presets {
ps, err := state.FilterByPreset(preset.ID) ps, err := state.FilterByPreset(preset.ID)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, ps)
actions, err := reconciler.CalculateActions(ctx, *ps)
require.NoError(t, err)
require.NotNil(t, actions)
// Then: the reconciliation takes place without error.
require.NoError(t, reconciler.ReconcilePreset(ctx, *ps)) require.NoError(t, reconciler.ReconcilePreset(ctx, *ps))
} }
}
require.Eventually(t, func() bool { // Given: a set of running, eligible prebuilds eventually starts up.
rows, err := spy.GetRunningPrebuiltWorkspaces(ctx) runningPrebuilds := make(map[uuid.UUID]database.GetRunningPrebuiltWorkspacesRow, desiredInstances*presetCount)
if err != nil { require.Eventually(t, func() bool {
return false rows, err := spy.GetRunningPrebuiltWorkspaces(ctx)
if err != nil {
return false
}
for _, row := range rows {
runningPrebuilds[row.CurrentPresetID.UUID] = row
if !tc.markPrebuildsClaimable {
continue
}
agents, err := db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, row.ID)
if err != nil {
return false
}
// Workspaces are eligible once its agent is marked "ready".
for _, agent := range agents {
err = db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: agent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: sql.NullTime{Time: time.Now().Add(time.Hour), Valid: true},
ReadyAt: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
})
if err != nil {
return false
}
}
}
t.Logf("found %d running prebuilds so far, want %d", len(runningPrebuilds), expectedPrebuildsCount)
return len(runningPrebuilds) == expectedPrebuildsCount
}, testutil.WaitSuperLong, testutil.IntervalSlow)
// When: a user creates a new workspace with a preset for which prebuilds are configured.
workspaceName := strings.ReplaceAll(testutil.GetRandomName(t), "_", "-")
params := database.ClaimPrebuiltWorkspaceParams{
NewUserID: user.ID,
NewName: workspaceName,
PresetID: presets[0].ID,
}
userWorkspace, err := userClient.CreateUserWorkspace(ctx, user.Username, codersdk.CreateWorkspaceRequest{
TemplateVersionID: version.ID,
Name: workspaceName,
TemplateVersionPresetID: presets[0].ID,
})
isNoPrebuiltWorkspaces := errors.Is(tc.claimingErr, agplprebuilds.ErrNoClaimablePrebuiltWorkspaces)
isUnsupported := errors.Is(tc.claimingErr, agplprebuilds.ErrAGPLDoesNotSupportPrebuiltWorkspaces)
switch {
case tc.claimingErr != nil && (isNoPrebuiltWorkspaces || isUnsupported):
require.NoError(t, err)
build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)
_ = build
// Then: the number of running prebuilds hasn't changed because claiming prebuild is failed and we fallback to creating new workspace.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return
case tc.claimingErr != nil && errors.Is(tc.claimingErr, unexpectedClaimingError):
// Then: unexpected error happened and was propagated all the way to the caller
require.Error(t, err)
require.ErrorContains(t, err, unexpectedClaimingError.Error())
// Then: the number of running prebuilds hasn't changed because claiming prebuild is failed.
currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount, len(currentPrebuilds))
return
default:
// tc.claimingErr is nil scenario
require.NoError(t, err)
build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, userWorkspace.LatestBuild.ID)
require.Equal(t, build.Job.Status, codersdk.ProvisionerJobSucceeded)
} }
t.Logf("found %d running prebuilds so far, want %d", len(rows), expectedPrebuildsCount) // at this point we know that tc.claimingErr is nil
return len(runningPrebuilds) == expectedPrebuildsCount // Then: a prebuild should have been claimed.
}, testutil.WaitSuperLong, testutil.IntervalSlow) require.EqualValues(t, spy.claims.Load(), 1)
require.EqualValues(t, *spy.claimParams.Load(), params)
// Then: when restarting the created workspace (which claimed a prebuild), it should not try and claim a new prebuild. if !tc.expectPrebuildClaimed {
// Prebuilds should ONLY be used for net-new workspaces. require.Nil(t, spy.claimedWorkspace.Load())
// This is expected by default anyway currently since new workspaces and operations on existing workspaces return
// take different code paths, but it's worth validating. }
spy.claims.Store(0) // Reset counter because we need to check if any new claim requests happen. require.NotNil(t, spy.claimedWorkspace.Load())
claimed := *spy.claimedWorkspace.Load()
require.NotEqual(t, claimed.ID, uuid.Nil)
wp, err := userClient.WorkspaceBuildParameters(ctx, userWorkspace.LatestBuild.ID) // Then: the claimed prebuild must now be owned by the requester.
require.NoError(t, err) workspace, err := spy.GetWorkspaceByID(ctx, claimed.ID)
require.NoError(t, err)
require.Equal(t, user.ID, workspace.OwnerID)
stopBuild, err := userClient.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ // Then: the number of running prebuilds has changed since one was claimed.
TemplateVersionID: version.ID, currentPrebuilds, err := spy.GetRunningPrebuiltWorkspaces(ctx)
Transition: codersdk.WorkspaceTransitionStop, require.NoError(t, err)
require.Equal(t, expectedPrebuildsCount-1, len(currentPrebuilds))
// Then: the claimed prebuild is now missing from the running prebuilds set.
found := slices.ContainsFunc(currentPrebuilds, func(prebuild database.GetRunningPrebuiltWorkspacesRow) bool {
return prebuild.ID == claimed.ID
})
require.False(t, found, "claimed prebuild should not still be considered a running prebuild")
// Then: reconciling at this point will provision a new prebuild to replace the claimed one.
{
// Given: the reconciliation state is snapshot.
state, err = reconciler.SnapshotState(ctx, spy)
require.NoError(t, err)
// When: a reconciliation is setup for each preset.
for _, preset := range presets {
ps, err := state.FilterByPreset(preset.ID)
require.NoError(t, err)
// Then: the reconciliation takes place without error.
require.NoError(t, reconciler.ReconcilePreset(ctx, *ps))
}
}
require.Eventually(t, func() bool {
rows, err := spy.GetRunningPrebuiltWorkspaces(ctx)
if err != nil {
return false
}
t.Logf("found %d running prebuilds so far, want %d", len(rows), expectedPrebuildsCount)
return len(runningPrebuilds) == expectedPrebuildsCount
}, testutil.WaitSuperLong, testutil.IntervalSlow)
// Then: when restarting the created workspace (which claimed a prebuild), it should not try and claim a new prebuild.
// Prebuilds should ONLY be used for net-new workspaces.
// This is expected by default anyway currently since new workspaces and operations on existing workspaces
// take different code paths, but it's worth validating.
spy.claims.Store(0) // Reset counter because we need to check if any new claim requests happen.
wp, err := userClient.WorkspaceBuildParameters(ctx, userWorkspace.LatestBuild.ID)
require.NoError(t, err)
stopBuild, err := userClient.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: version.ID,
Transition: codersdk.WorkspaceTransitionStop,
})
require.NoError(t, err)
build := coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, stopBuild.ID)
require.Equal(t, build.Job.Status, codersdk.ProvisionerJobSucceeded)
startBuild, err := userClient.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: version.ID,
Transition: codersdk.WorkspaceTransitionStart,
RichParameterValues: wp,
})
require.NoError(t, err)
build = coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, startBuild.ID)
require.Equal(t, build.Job.Status, codersdk.ProvisionerJobSucceeded)
require.Zero(t, spy.claims.Load())
}) })
require.NoError(t, err) }
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, stopBuild.ID)
startBuild, err := userClient.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: version.ID,
Transition: codersdk.WorkspaceTransitionStart,
RichParameterValues: wp,
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, startBuild.ID)
require.Zero(t, spy.claims.Load())
})
} }
} }

View File

@ -0,0 +1,81 @@
package prebuilds
import (
"context"
"database/sql"
"errors"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/quartz"
)
// StoreMembershipReconciler encapsulates the responsibility of ensuring that the prebuilds system user is a member of all
// organizations for which prebuilt workspaces are requested. This is necessary because our data model requires that such
// prebuilt workspaces belong to a member of the organization of their eventual claimant.
type StoreMembershipReconciler struct {
store database.Store
clock quartz.Clock
}
func NewStoreMembershipReconciler(store database.Store, clock quartz.Clock) StoreMembershipReconciler {
return StoreMembershipReconciler{
store: store,
clock: clock,
}
}
// ReconcileAll compares the current membership of a user to the membership required in order to create prebuilt workspaces.
// If the user in question is not yet a member of an organization that needs prebuilt workspaces, ReconcileAll will create
// the membership required.
//
// This method does not have an opinion on transaction or lock management. These responsibilities are left to the caller.
func (s StoreMembershipReconciler) ReconcileAll(ctx context.Context, userID uuid.UUID, presets []database.GetTemplatePresetsWithPrebuildsRow) error {
organizationMemberships, err := s.store.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: userID,
Deleted: sql.NullBool{
Bool: false,
Valid: true,
},
})
if err != nil {
return xerrors.Errorf("determine prebuild organization membership: %w", err)
}
systemUserMemberships := make(map[uuid.UUID]struct{}, 0)
defaultOrg, err := s.store.GetDefaultOrganization(ctx)
if err != nil {
return xerrors.Errorf("get default organization: %w", err)
}
systemUserMemberships[defaultOrg.ID] = struct{}{}
for _, o := range organizationMemberships {
systemUserMemberships[o.ID] = struct{}{}
}
var membershipInsertionErrors error
for _, preset := range presets {
_, alreadyMember := systemUserMemberships[preset.OrganizationID]
if alreadyMember {
continue
}
// Add the organization to our list of memberships regardless of potential failure below
// to avoid a retry that will probably be doomed anyway.
systemUserMemberships[preset.OrganizationID] = struct{}{}
// Insert the missing membership
_, err = s.store.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{
OrganizationID: preset.OrganizationID,
UserID: userID,
CreatedAt: s.clock.Now(),
UpdatedAt: s.clock.Now(),
Roles: []string{},
})
if err != nil {
membershipInsertionErrors = errors.Join(membershipInsertionErrors, xerrors.Errorf("insert membership for prebuilt workspaces: %w", err))
continue
}
}
return membershipInsertionErrors
}

View File

@ -0,0 +1,127 @@
package prebuilds_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/quartz"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/enterprise/coderd/prebuilds"
)
// TestReconcileAll verifies that StoreMembershipReconciler correctly updates membership
// for the prebuilds system user.
func TestReconcileAll(t *testing.T) {
t.Parallel()
ctx := context.Background()
clock := quartz.NewMock(t)
// Helper to build a minimal Preset row belonging to a given org.
newPresetRow := func(orgID uuid.UUID) database.GetTemplatePresetsWithPrebuildsRow {
return database.GetTemplatePresetsWithPrebuildsRow{
ID: uuid.New(),
OrganizationID: orgID,
}
}
tests := []struct {
name string
includePreset bool
preExistingMembership bool
}{
// The StoreMembershipReconciler acts based on the provided agplprebuilds.GlobalSnapshot.
// These test cases must therefore trust any valid snapshot, so the only relevant functional test cases are:
// No presets to act on and the prebuilds user does not belong to any organizations.
// Reconciliation should be a no-op
{name: "no presets, no memberships", includePreset: false, preExistingMembership: false},
// If we have a preset that requires prebuilds, but the prebuilds user is not a member of
// that organization, then we should add the membership.
{name: "preset, but no membership", includePreset: true, preExistingMembership: false},
// If the prebuilds system user is already a member of the organization to which a preset belongs,
// then reconciliation should be a no-op:
{name: "preset, but already a member", includePreset: true, preExistingMembership: true},
// If the prebuilds system user is a member of an organization that doesn't have need any prebuilds,
// then it must have required prebuilds in the past. The membership is not currently necessary, but
// the reconciler won't remove it, because there's little cost to keeping it and prebuilds might be
// enabled again.
{name: "member, but no presets", includePreset: false, preExistingMembership: true},
}
for _, tc := range tests {
tc := tc // capture
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
defaultOrg, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
// introduce an unrelated organization to ensure that the membership reconciler don't interfere with it.
unrelatedOrg := dbgen.Organization(t, db, database.Organization{})
targetOrg := dbgen.Organization(t, db, database.Organization{})
if !dbtestutil.WillUsePostgres() {
// dbmem doesn't ensure membership to the default organization
dbgen.OrganizationMember(t, db, database.OrganizationMember{
OrganizationID: defaultOrg.ID,
UserID: agplprebuilds.SystemUserID,
})
}
dbgen.OrganizationMember(t, db, database.OrganizationMember{OrganizationID: unrelatedOrg.ID, UserID: agplprebuilds.SystemUserID})
if tc.preExistingMembership {
// System user already a member of both orgs.
dbgen.OrganizationMember(t, db, database.OrganizationMember{OrganizationID: targetOrg.ID, UserID: agplprebuilds.SystemUserID})
}
presets := []database.GetTemplatePresetsWithPrebuildsRow{newPresetRow(unrelatedOrg.ID)}
if tc.includePreset {
presets = append(presets, newPresetRow(targetOrg.ID))
}
// Verify memberships before reconciliation.
preReconcileMemberships, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: agplprebuilds.SystemUserID,
})
require.NoError(t, err)
expectedMembershipsBefore := []uuid.UUID{defaultOrg.ID, unrelatedOrg.ID}
if tc.preExistingMembership {
expectedMembershipsBefore = append(expectedMembershipsBefore, targetOrg.ID)
}
require.ElementsMatch(t, expectedMembershipsBefore, extractOrgIDs(preReconcileMemberships))
// Reconcile
reconciler := prebuilds.NewStoreMembershipReconciler(db, clock)
require.NoError(t, reconciler.ReconcileAll(ctx, agplprebuilds.SystemUserID, presets))
// Verify memberships after reconciliation.
postReconcileMemberships, err := db.GetOrganizationsByUserID(ctx, database.GetOrganizationsByUserIDParams{
UserID: agplprebuilds.SystemUserID,
})
require.NoError(t, err)
expectedMembershipsAfter := expectedMembershipsBefore
if !tc.preExistingMembership && tc.includePreset {
expectedMembershipsAfter = append(expectedMembershipsAfter, targetOrg.ID)
}
require.ElementsMatch(t, expectedMembershipsAfter, extractOrgIDs(postReconcileMemberships))
})
}
}
func extractOrgIDs(orgs []database.Organization) []uuid.UUID {
ids := make([]uuid.UUID, len(orgs))
for i, o := range orgs {
ids[i] = o.ID
}
return ids
}

View File

@ -251,8 +251,8 @@ func (c *StoreReconciler) ReconcileAll(ctx context.Context) error {
logger.Debug(ctx, "starting reconciliation") logger.Debug(ctx, "starting reconciliation")
err := c.WithReconciliationLock(ctx, logger, func(ctx context.Context, db database.Store) error { err := c.WithReconciliationLock(ctx, logger, func(ctx context.Context, _ database.Store) error {
snapshot, err := c.SnapshotState(ctx, db) snapshot, err := c.SnapshotState(ctx, c.store)
if err != nil { if err != nil {
return xerrors.Errorf("determine current snapshot: %w", err) return xerrors.Errorf("determine current snapshot: %w", err)
} }
@ -264,6 +264,12 @@ func (c *StoreReconciler) ReconcileAll(ctx context.Context) error {
return nil return nil
} }
membershipReconciler := NewStoreMembershipReconciler(c.store, c.clock)
err = membershipReconciler.ReconcileAll(ctx, prebuilds.SystemUserID, snapshot.Presets)
if err != nil {
return xerrors.Errorf("reconcile prebuild membership: %w", err)
}
var eg errgroup.Group var eg errgroup.Group
// Reconcile presets in parallel. Each preset in its own goroutine. // Reconcile presets in parallel. Each preset in its own goroutine.
for _, preset := range snapshot.Presets { for _, preset := range snapshot.Presets {

View File

@ -43,7 +43,7 @@ func TestNoReconciliationActionsIfNoPresets(t *testing.T) {
t.Parallel() t.Parallel()
if !dbtestutil.WillUsePostgres() { if !dbtestutil.WillUsePostgres() {
t.Skip("This test requires postgres") t.Skip("dbmem times out on nesting transactions, postgres ignores the inner ones")
} }
clock := quartz.NewMock(t) clock := quartz.NewMock(t)
@ -88,7 +88,7 @@ func TestNoReconciliationActionsIfNoPrebuilds(t *testing.T) {
t.Parallel() t.Parallel()
if !dbtestutil.WillUsePostgres() { if !dbtestutil.WillUsePostgres() {
t.Skip("This test requires postgres") t.Skip("dbmem times out on nesting transactions, postgres ignores the inner ones")
} }
clock := quartz.NewMock(t) clock := quartz.NewMock(t)

View File

@ -11,7 +11,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/serpent" "github.com/coder/serpent"
"github.com/google/uuid" "github.com/google/uuid"
@ -84,8 +87,6 @@ func TestBlockNonBrowser(t *testing.T) {
func TestReinitializeAgent(t *testing.T) { func TestReinitializeAgent(t *testing.T) {
t.Parallel() t.Parallel()
tempAgentLog := testutil.CreateTemp(t, "", "testReinitializeAgent")
if !dbtestutil.WillUsePostgres() { if !dbtestutil.WillUsePostgres() {
t.Skip("dbmem cannot currently claim a workspace") t.Skip("dbmem cannot currently claim a workspace")
} }
@ -94,79 +95,98 @@ func TestReinitializeAgent(t *testing.T) {
t.Skip("test startup script is not supported on windows") t.Skip("test startup script is not supported on windows")
} }
startupScript := fmt.Sprintf("printenv >> %s; echo '---\n' >> %s", tempAgentLog.Name(), tempAgentLog.Name()) // Ensure that workspace agents can reinitialize against claimed prebuilds in non-default organizations:
for _, useDefaultOrg := range []bool{true, false} {
t.Run("", func(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t) tempAgentLog := testutil.CreateTemp(t, "", "testReinitializeAgent")
// GIVEN a live enterprise API with the prebuilds feature enabled
client, user := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
Database: db,
Pubsub: ps,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Prebuilds.ReconciliationInterval = serpent.Duration(time.Second)
dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
}),
IncludeProvisionerDaemon: true,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureWorkspacePrebuilds: 1,
},
},
})
// GIVEN a template, template version, preset and a prebuilt workspace that uses them all startupScript := fmt.Sprintf("printenv >> %s; echo '---\n' >> %s", tempAgentLog.Name(), tempAgentLog.Name())
agentToken := uuid.UUID{3}
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ db, ps := dbtestutil.NewDB(t)
Parse: echo.ParseComplete, // GIVEN a live enterprise API with the prebuilds feature enabled
ProvisionPlan: []*proto.Response{ client, user := coderdenttest.New(t, &coderdenttest.Options{
{ Options: &coderdtest.Options{
Type: &proto.Response_Plan{ Database: db,
Plan: &proto.PlanComplete{ Pubsub: ps,
Presets: []*proto.Preset{ DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
{ dv.Prebuilds.ReconciliationInterval = serpent.Duration(time.Second)
Name: "test-preset", dv.Experiments.Append(string(codersdk.ExperimentWorkspacePrebuilds))
Prebuild: &proto.Prebuild{ }),
Instances: 1, },
}, LicenseOptions: &coderdenttest.LicenseOptions{
}, Features: license.Features{
}, codersdk.FeatureWorkspacePrebuilds: 1,
Resources: []*proto.Resource{ codersdk.FeatureExternalProvisionerDaemons: 1,
{
Agents: []*proto.Agent{
{
Name: "smith",
OperatingSystem: "linux",
Architecture: "i386",
},
},
},
},
}, },
}, },
}, })
},
ProvisionApply: []*proto.Response{ orgID := user.OrganizationID
{ if !useDefaultOrg {
Type: &proto.Response_Apply{ secondOrg := dbgen.Organization(t, db, database.Organization{})
Apply: &proto.ApplyComplete{ orgID = secondOrg.ID
Resources: []*proto.Resource{ }
{ provisionerCloser := coderdenttest.NewExternalProvisionerDaemon(t, client, orgID, map[string]string{
Type: "compute", provisionersdk.TagScope: provisionersdk.ScopeOrganization,
Name: "main", })
Agents: []*proto.Agent{ defer provisionerCloser.Close()
// GIVEN a template, template version, preset and a prebuilt workspace that uses them all
agentToken := uuid.UUID{3}
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Presets: []*proto.Preset{
{ {
Name: "smith", Name: "test-preset",
OperatingSystem: "linux", Prebuild: &proto.Prebuild{
Architecture: "i386", Instances: 1,
Scripts: []*proto.Script{ },
},
},
Resources: []*proto.Resource{
{
Agents: []*proto.Agent{
{ {
RunOnStart: true, Name: "smith",
Script: startupScript, OperatingSystem: "linux",
Architecture: "i386",
}, },
}, },
Auth: &proto.Agent_Token{ },
Token: agentToken.String(), },
},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Type: "compute",
Name: "main",
Agents: []*proto.Agent{
{
Name: "smith",
OperatingSystem: "linux",
Architecture: "i386",
Scripts: []*proto.Script{
{
RunOnStart: true,
Script: startupScript,
},
},
Auth: &proto.Agent_Token{
Token: agentToken.String(),
},
},
}, },
}, },
}, },
@ -174,79 +194,76 @@ func TestReinitializeAgent(t *testing.T) {
}, },
}, },
}, },
}, })
}, coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) coderdtest.CreateTemplate(t, client, orgID, version.ID)
// Wait for prebuilds to create a prebuilt workspace // Wait for prebuilds to create a prebuilt workspace
ctx := context.Background() ctx := testutil.Context(t, testutil.WaitLong)
// ctx := testutil.Context(t, testutil.WaitLong) var prebuildID uuid.UUID
var ( require.Eventually(t, func() bool {
prebuildID uuid.UUID agentAndBuild, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, agentToken)
) if err != nil {
require.Eventually(t, func() bool { return false
agentAndBuild, err := db.GetWorkspaceAgentAndLatestBuildByAuthToken(ctx, agentToken) }
if err != nil { prebuildID = agentAndBuild.WorkspaceBuild.ID
return false return true
} }, testutil.WaitLong, testutil.IntervalFast)
prebuildID = agentAndBuild.WorkspaceBuild.ID
return true
}, testutil.WaitLong, testutil.IntervalFast)
prebuild := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuildID) prebuild := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, prebuildID)
preset, err := db.GetPresetByWorkspaceBuildID(ctx, prebuildID) preset, err := db.GetPresetByWorkspaceBuildID(ctx, prebuildID)
require.NoError(t, err) require.NoError(t, err)
// GIVEN a running agent // GIVEN a running agent
logDir := t.TempDir() logDir := t.TempDir()
inv, _ := clitest.New(t, inv, _ := clitest.New(t,
"agent", "agent",
"--auth", "token", "--auth", "token",
"--agent-token", agentToken.String(), "--agent-token", agentToken.String(),
"--agent-url", client.URL.String(), "--agent-url", client.URL.String(),
"--log-dir", logDir, "--log-dir", logDir,
) )
clitest.Start(t, inv) clitest.Start(t, inv)
// GIVEN the agent is in a happy steady state // GIVEN the agent is in a happy steady state
waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, prebuild.WorkspaceID) waiter := coderdtest.NewWorkspaceAgentWaiter(t, client, prebuild.WorkspaceID)
waiter.WaitFor(coderdtest.AgentsReady) waiter.WaitFor(coderdtest.AgentsReady)
// WHEN a workspace is created that can benefit from prebuilds // WHEN a workspace is created that can benefit from prebuilds
anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) anotherClient, anotherUser := coderdtest.CreateAnotherUser(t, client, orgID)
workspace, err := anotherClient.CreateUserWorkspace(ctx, anotherUser.ID.String(), codersdk.CreateWorkspaceRequest{ workspace, err := anotherClient.CreateUserWorkspace(ctx, anotherUser.ID.String(), codersdk.CreateWorkspaceRequest{
TemplateVersionID: version.ID, TemplateVersionID: version.ID,
TemplateVersionPresetID: preset.ID, TemplateVersionPresetID: preset.ID,
Name: "claimed-workspace", Name: "claimed-workspace",
}) })
require.NoError(t, err) require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// THEN reinitialization completes // THEN reinitialization completes
waiter.WaitFor(coderdtest.AgentsReady) waiter.WaitFor(coderdtest.AgentsReady)
var matches [][]byte var matches [][]byte
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
// THEN the agent script ran again and reused the same agent token // THEN the agent script ran again and reused the same agent token
contents, err := os.ReadFile(tempAgentLog.Name()) contents, err := os.ReadFile(tempAgentLog.Name())
if err != nil { if err != nil {
return false return false
} }
// UUID regex pattern (matches UUID v4-like strings) // UUID regex pattern (matches UUID v4-like strings)
uuidRegex := regexp.MustCompile(`\bCODER_AGENT_TOKEN=(.+)\b`) uuidRegex := regexp.MustCompile(`\bCODER_AGENT_TOKEN=(.+)\b`)
matches = uuidRegex.FindAll(contents, -1) matches = uuidRegex.FindAll(contents, -1)
// When an agent reinitializes, we expect it to run startup scripts again. // When an agent reinitializes, we expect it to run startup scripts again.
// As such, we expect to have written the agent environment to the temp file twice. // As such, we expect to have written the agent environment to the temp file twice.
// Once on initial startup and then once on reinitialization. // Once on initial startup and then once on reinitialization.
return len(matches) == 2 return len(matches) == 2
}, testutil.WaitLong, testutil.IntervalMedium) }, testutil.WaitLong, testutil.IntervalMedium)
require.Equal(t, matches[0], matches[1]) require.Equal(t, matches[0], matches[1])
})
}
} }
type setupResp struct { type setupResp struct {