Mirror of https://github.com/coder/coder.git (synced 2025-07-08 11:39:50 +00:00)
feat: add Acquirer to provisionerdserver pkg (#9658)
* chore: add Acquirer to provisionerdserver pkg

Signed-off-by: Spike Curtis <spike@coder.com>

* code review improvements & fixes

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
492 coderd/provisionerdserver/acquirer.go Normal file
@@ -0,0 +1,492 @@
package provisionerdserver

import (
    "context"
    "database/sql"
    "encoding/json"
    "strings"
    "sync"
    "time"

    "github.com/cenkalti/backoff/v4"
    "github.com/google/uuid"
    "golang.org/x/exp/slices"
    "golang.org/x/xerrors"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/dbauthz"
    "github.com/coder/coder/v2/coderd/database/dbtime"
    "github.com/coder/coder/v2/coderd/database/pubsub"
)

const (
    EventJobPosted = "provisioner_job_posted"
    dbMaxBackoff   = 10 * time.Second
    // backupPollDuration is the period for the backup polling described in the Acquirer comment
    backupPollDuration = 30 * time.Second
)

// Acquirer is shared among multiple routines that need to call
// database.Store.AcquireProvisionerJob. The callers that acquire jobs are called "acquirees". The
// goal is to minimize polling the database (i.e. lower our average query rate) and simplify the
// acquiree's logic by handling retrying the database if a job is not available at the time of the
// call.
//
// When multiple acquirees share a set of provisioner types and tags, we define them as part of the
// same "domain". Only one acquiree from each domain may query the database at a time. If the
// database returns no jobs for that acquiree, the entire domain waits until the Acquirer is
// notified over the pubsub of a new job acceptable to the domain.
//
// As a backup to pubsub notifications, each domain is allowed to query periodically once every 30s.
// This ensures jobs are not stuck permanently if the service that created them fails to publish
// (e.g. a crash).
type Acquirer struct {
    ctx    context.Context
    logger slog.Logger
    store  AcquirerStore
    ps     pubsub.Pubsub

    mu sync.Mutex
    q  map[dKey]domain

    // testing only
    backupPollDuration time.Duration
}

type AcquirerOption func(*Acquirer)

func TestingBackupPollDuration(dur time.Duration) AcquirerOption {
    return func(a *Acquirer) {
        a.backupPollDuration = dur
    }
}

// AcquirerStore is the subset of database.Store that the Acquirer needs
type AcquirerStore interface {
    AcquireProvisionerJob(context.Context, database.AcquireProvisionerJobParams) (database.ProvisionerJob, error)
}

func NewAcquirer(ctx context.Context, logger slog.Logger, store AcquirerStore, ps pubsub.Pubsub,
    opts ...AcquirerOption,
) *Acquirer {
    a := &Acquirer{
        ctx:                ctx,
        logger:             logger,
        store:              store,
        ps:                 ps,
        q:                  make(map[dKey]domain),
        backupPollDuration: backupPollDuration,
    }
    for _, opt := range opts {
        opt(a)
    }
    a.subscribe()
    return a
}

// AcquireJob acquires a job with one of the given provisioner types and compatible
// tags from the database. The call blocks until a job is acquired, the context is
// done, or the database returns an error _other_ than that no jobs are available.
// If no jobs are available, this method handles retrying as appropriate.
func (a *Acquirer) AcquireJob(
    ctx context.Context, worker uuid.UUID, pt []database.ProvisionerType, tags Tags,
) (
    retJob database.ProvisionerJob, retErr error,
) {
    logger := a.logger.With(
        slog.F("worker_id", worker),
        slog.F("provisioner_types", pt),
        slog.F("tags", tags))
    logger.Debug(ctx, "acquiring job")
    dk := domainKey(pt, tags)
    dbTags, err := tags.ToJSON()
    if err != nil {
        return database.ProvisionerJob{}, err
    }
    // buffer of 1 so that cancel doesn't deadlock while writing to the channel
    clearance := make(chan struct{}, 1)
    //nolint:gocritic // Provisionerd has specific authz rules.
    principal := dbauthz.AsProvisionerd(ctx)
    for {
        a.want(pt, tags, clearance)
        select {
        case <-ctx.Done():
            err := ctx.Err()
            logger.Debug(ctx, "acquiring job canceled", slog.Error(err))
            internalError := a.cancel(dk, clearance)
            if internalError != nil {
                // internalError takes precedence
                return database.ProvisionerJob{}, internalError
            }
            return database.ProvisionerJob{}, err
        case <-clearance:
            logger.Debug(ctx, "got clearance to call database")
            job, err := a.store.AcquireProvisionerJob(principal, database.AcquireProvisionerJobParams{
                StartedAt: sql.NullTime{
                    Time:  dbtime.Now(),
                    Valid: true,
                },
                WorkerID: uuid.NullUUID{
                    UUID:  worker,
                    Valid: true,
                },
                Types: pt,
                Tags:  dbTags,
            })
            if xerrors.Is(err, sql.ErrNoRows) {
                logger.Debug(ctx, "no job available")
                continue
            }
            // we are not going to retry, so signal we are done
            internalError := a.done(dk, clearance)
            if internalError != nil {
                // internal error takes precedence
                return database.ProvisionerJob{}, internalError
            }
            if err != nil {
                logger.Warn(ctx, "error attempting to acquire job", slog.Error(err))
                return database.ProvisionerJob{}, xerrors.Errorf("failed to acquire job: %w", err)
            }
            logger.Debug(ctx, "successfully acquired job")
            return job, nil
        }
    }
}

// want signals that an acquiree wants clearance to query for a job with the given dKey.
func (a *Acquirer) want(pt []database.ProvisionerType, tags Tags, clearance chan<- struct{}) {
    dk := domainKey(pt, tags)
    a.mu.Lock()
    defer a.mu.Unlock()
    cleared := false
    d, ok := a.q[dk]
    if !ok {
        ctx, cancel := context.WithCancel(a.ctx)
        d = domain{
            ctx:       ctx,
            cancel:    cancel,
            a:         a,
            key:       dk,
            pt:        pt,
            tags:      tags,
            acquirees: make(map[chan<- struct{}]*acquiree),
        }
        a.q[dk] = d
        go d.poll(a.backupPollDuration)
        // this is a new request for this dKey, so is cleared.
        cleared = true
    }
    w, ok := d.acquirees[clearance]
    if !ok {
        w = &acquiree{clearance: clearance}
        d.acquirees[clearance] = w
    }
    // pending means that we got a job posting for this dKey while we were
    // querying, so we should clear this acquiree to retry another time.
    if w.pending {
        cleared = true
        w.pending = false
    }
    w.inProgress = cleared
    if cleared {
        // this won't block because clearance is buffered.
        clearance <- struct{}{}
    }
}

// cancel signals that an acquiree no longer wants clearance to query. Any error returned is a serious internal error
// indicating that integrity of the internal state is corrupted by a code bug.
func (a *Acquirer) cancel(dk dKey, clearance chan<- struct{}) error {
    a.mu.Lock()
    defer a.mu.Unlock()
    d, ok := a.q[dk]
    if !ok {
        // this is a code error, as something removed the domain early, or cancel
        // was called twice.
        err := xerrors.New("cancel for domain that doesn't exist")
        a.logger.Critical(a.ctx, "internal error", slog.Error(err))
        return err
    }
    w, ok := d.acquirees[clearance]
    if !ok {
        // this is a code error, as something removed the acquiree early, or cancel
        // was called twice.
        err := xerrors.New("cancel for an acquiree that doesn't exist")
        a.logger.Critical(a.ctx, "internal error", slog.Error(err))
        return err
    }
    delete(d.acquirees, clearance)
    if w.inProgress && len(d.acquirees) > 0 {
        // this one canceled before querying, so give another acquiree a chance
        // instead
        for _, other := range d.acquirees {
            if other.inProgress {
                err := xerrors.New("more than one acquiree in progress for same key")
                a.logger.Critical(a.ctx, "internal error", slog.Error(err))
                return err
            }
            other.inProgress = true
            other.clearance <- struct{}{}
            break // just one
        }
    }
    if len(d.acquirees) == 0 {
        d.cancel()
        delete(a.q, dk)
    }
    return nil
}

// done signals that the acquiree has completed acquiring a job (usually successfully, but we also get this call if
// there is a database error other than ErrNoRows). Any error returned is a serious internal error indicating that
// integrity of the internal state is corrupted by a code bug.
func (a *Acquirer) done(dk dKey, clearance chan struct{}) error {
    a.mu.Lock()
    defer a.mu.Unlock()
    d, ok := a.q[dk]
    if !ok {
        // this is a code error, as something removed the domain early, or done
        // was called twice.
        err := xerrors.New("done for a domain that doesn't exist")
        a.logger.Critical(a.ctx, "internal error", slog.Error(err))
        return err
    }
    w, ok := d.acquirees[clearance]
    if !ok {
        // this is a code error, as something removed the dKey early, or done
        // was called twice.
        err := xerrors.New("done for an acquiree that doesn't exist")
        a.logger.Critical(a.ctx, "internal error", slog.Error(err))
        return err
    }
    if !w.inProgress {
        err := xerrors.New("done acquiree was not in progress")
        a.logger.Critical(a.ctx, "internal error", slog.Error(err))
        return err
    }
    delete(d.acquirees, clearance)
    if len(d.acquirees) == 0 {
        d.cancel()
        delete(a.q, dk)
        return nil
    }
    // in the mainline, this means that the acquiree successfully got a job.
    // if any others are waiting, clear one of them to try to get a job next so
    // that we process the jobs until there are no more acquirees or the database
    // is empty of jobs meeting our criteria
    for _, other := range d.acquirees {
        if other.inProgress {
            err := xerrors.New("more than one acquiree in progress for same key")
            a.logger.Critical(a.ctx, "internal error", slog.Error(err))
            return err
        }
        other.inProgress = true
        other.clearance <- struct{}{}
        break // just one
    }
    return nil
}

func (a *Acquirer) subscribe() {
    subscribed := make(chan struct{})
    go func() {
        defer close(subscribed)
        eb := backoff.NewExponentialBackOff()
        eb.MaxElapsedTime = 0 // retry indefinitely
        eb.MaxInterval = dbMaxBackoff
        bkoff := backoff.WithContext(eb, a.ctx)
        var cancel context.CancelFunc
        err := backoff.Retry(func() error {
            cancelFn, err := a.ps.SubscribeWithErr(EventJobPosted, a.jobPosted)
            if err != nil {
                a.logger.Warn(a.ctx, "failed to subscribe to job postings", slog.Error(err))
                return err
            }
            cancel = cancelFn
            return nil
        }, bkoff)
        if err != nil {
            if a.ctx.Err() == nil {
                a.logger.Error(a.ctx, "code bug: retry failed before context canceled", slog.Error(err))
            }
            return
        }
        defer cancel()
        bkoff.Reset()
        a.logger.Debug(a.ctx, "subscribed to job postings")

        // unblock the outer function from returning
        subscribed <- struct{}{}

        // hold subscriptions open until context is canceled
        <-a.ctx.Done()
    }()
    <-subscribed
}

func (a *Acquirer) jobPosted(ctx context.Context, message []byte, err error) {
    if xerrors.Is(err, pubsub.ErrDroppedMessages) {
        a.logger.Warn(a.ctx, "pubsub may have dropped job postings")
        a.clearOrPendAll()
        return
    }
    if err != nil {
        a.logger.Warn(a.ctx, "unhandled pubsub error", slog.Error(err))
        return
    }
    posting := JobPosting{}
    err = json.Unmarshal(message, &posting)
    if err != nil {
        a.logger.Error(a.ctx, "unable to parse job posting",
            slog.F("message", string(message)),
            slog.Error(err),
        )
        return
    }
    a.logger.Debug(ctx, "got job posting", slog.F("posting", posting))

    a.mu.Lock()
    defer a.mu.Unlock()
    for _, d := range a.q {
        if d.contains(posting) {
            a.clearOrPendLocked(d)
            // we only need to wake up a single domain since there is only one
            // new job available
            return
        }
    }
}

func (a *Acquirer) clearOrPendAll() {
    a.mu.Lock()
    defer a.mu.Unlock()
    for _, d := range a.q {
        a.clearOrPendLocked(d)
    }
}

func (a *Acquirer) clearOrPend(d domain) {
    a.mu.Lock()
    defer a.mu.Unlock()
    if len(d.acquirees) == 0 {
        // this can happen if the domain is removed right around the time the
        // backup poll (which calls this function) triggers. Nothing to do
        // since there are no acquirees.
        return
    }
    a.clearOrPendLocked(d)
}

func (*Acquirer) clearOrPendLocked(d domain) {
    // MUST BE CALLED HOLDING THE a.mu LOCK
    var nominee *acquiree
    for _, w := range d.acquirees {
        if nominee == nil {
            nominee = w
        }
        // acquiree in progress always takes precedence, since we don't want to
        // wake up more than one acquiree per dKey at a time.
        if w.inProgress {
            nominee = w
            break
        }
    }
    if nominee.inProgress {
        nominee.pending = true
        return
    }
    nominee.inProgress = true
    nominee.clearance <- struct{}{}
}

type dKey string

// domainKey generates a canonical map key for the given provisioner types and
// tags. It uses the null byte (0x00) as a delimiter because it is an
// unprintable control character and won't show up in any "reasonable" set of
// string tags, even in non-Latin scripts. It is important that Tags are
// validated not to contain this control character prior to use.
func domainKey(pt []database.ProvisionerType, tags Tags) dKey {
    // make a copy of pt before sorting, so that we don't mutate the original
    // slice or underlying array.
    pts := make([]database.ProvisionerType, len(pt))
    copy(pts, pt)
    slices.Sort(pts)
    sb := strings.Builder{}
    for _, t := range pts {
        _, _ = sb.WriteString(string(t))
        _ = sb.WriteByte(0x00)
    }
    _ = sb.WriteByte(0x00)
    var keys []string
    for k := range tags {
        keys = append(keys, k)
    }
    slices.Sort(keys)
    for _, k := range keys {
        _, _ = sb.WriteString(k)
        _ = sb.WriteByte(0x00)
        _, _ = sb.WriteString(tags[k])
        _ = sb.WriteByte(0x00)
    }
    return dKey(sb.String())
}

// acquiree represents a specific client of Acquirer that wants to acquire a job
type acquiree struct {
    clearance chan<- struct{}
    // inProgress is true when the acquiree was granted clearance and a query
    // is possibly in progress.
    inProgress bool
    // pending is true if we get a job posting while a query is in progress, so
    // that we know to try again, even if we didn't get a job on the query.
    pending bool
}

// domain represents a set of acquirees with the same provisioner types and
// tags. Acquirees in the same domain are restricted such that only one queries
// the database at a time.
type domain struct {
    ctx       context.Context
    cancel    context.CancelFunc
    a         *Acquirer
    key       dKey
    pt        []database.ProvisionerType
    tags      Tags
    acquirees map[chan<- struct{}]*acquiree
}

func (d domain) contains(p JobPosting) bool {
    if !slices.Contains(d.pt, p.ProvisionerType) {
        return false
    }
    for k, v := range p.Tags {
        dv, ok := d.tags[k]
        if !ok {
            return false
        }
        if v != dv {
            return false
        }
    }
    return true
}

func (d domain) poll(dur time.Duration) {
    tkr := time.NewTicker(dur)
    defer tkr.Stop()
    for {
        select {
        case <-d.ctx.Done():
            return
        case <-tkr.C:
            d.a.clearOrPend(d)
        }
    }
}

type JobPosting struct {
    ProvisionerType database.ProvisionerType `json:"type"`
    Tags            map[string]string        `json:"tags"`
}
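For orientation, here is a minimal sketch of how a caller is expected to use the new Acquirer: construct one shared instance, then block on AcquireJob per worker. The wrapper function, tag values, and surrounding wiring below are assumptions for illustration only and are not part of this commit.

// Illustrative sketch only: shows the intended calling pattern for Acquirer.
package example

import (
    "context"

    "github.com/google/uuid"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/pubsub"
    "github.com/coder/coder/v2/coderd/provisionerdserver"
)

// serveDaemon blocks until a job is available for this worker's provisioner
// types and tags, then hands it off for processing. (Hypothetical helper.)
func serveDaemon(ctx context.Context, logger slog.Logger, db provisionerdserver.AcquirerStore, ps pubsub.Pubsub) error {
    // One Acquirer is shared by all daemon connections; it dedupes database
    // polling per (provisioner types, tags) domain.
    acq := provisionerdserver.NewAcquirer(ctx, logger, db, ps)

    workerID := uuid.New()
    types := []database.ProvisionerType{database.ProvisionerTypeTerraform}
    tags := provisionerdserver.Tags{"environment": "on-prem"} // example tag set

    // Blocks until a job is acquired, ctx is done, or the store returns an
    // error other than sql.ErrNoRows.
    job, err := acq.AcquireJob(ctx, workerID, types, tags)
    if err != nil {
        return err
    }
    logger.Info(ctx, "processing job", slog.F("job_id", job.ID))
    return nil
}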
512 coderd/provisionerdserver/acquirer_test.go Normal file
@@ -0,0 +1,512 @@
package provisionerdserver_test

import (
    "context"
    "database/sql"
    "encoding/json"
    "sync"
    "testing"
    "time"

    "github.com/google/uuid"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "go.uber.org/goleak"
    "golang.org/x/exp/slices"

    "cdr.dev/slog"
    "cdr.dev/slog/sloggers/slogtest"
    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/dbfake"
    "github.com/coder/coder/v2/coderd/database/pubsub"
    "github.com/coder/coder/v2/coderd/provisionerdserver"
    "github.com/coder/coder/v2/testutil"
)

func TestMain(m *testing.M) {
    goleak.VerifyTestMain(m)
}

// TestAcquirer_Store tests that a database.Store is accepted as a provisionerdserver.AcquirerStore
func TestAcquirer_Store(t *testing.T) {
    t.Parallel()
    db := dbfake.New()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    _ = provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps)
}

func TestAcquirer_Single(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    workerID := uuid.New()
    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    acquiree := newTestAcquiree(t, workerID, pt, tags)
    jobID := uuid.New()
    err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
    require.NoError(t, err)
    acquiree.startAcquire(ctx, uut)
    job := acquiree.success(ctx)
    require.Equal(t, jobID, job.ID)
    require.Len(t, fs.params, 1)
    require.Equal(t, workerID, fs.params[0].WorkerID.UUID)
}

// TestAcquirer_MultipleSameDomain tests multiple acquirees with the same provisioners and tags
func TestAcquirer_MultipleSameDomain(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    acquirees := make([]*testAcquiree, 0, 10)
    jobIDs := make(map[uuid.UUID]bool)
    workerIDs := make(map[uuid.UUID]bool)
    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    for i := 0; i < 10; i++ {
        wID := uuid.New()
        workerIDs[wID] = true
        a := newTestAcquiree(t, wID, pt, tags)
        acquirees = append(acquirees, a)
        a.startAcquire(ctx, uut)
    }
    for i := 0; i < 10; i++ {
        jobID := uuid.New()
        jobIDs[jobID] = true
        err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
        require.NoError(t, err)
    }
    gotJobIDs := make(map[uuid.UUID]bool)
    for i := 0; i < 10; i++ {
        j := acquirees[i].success(ctx)
        gotJobIDs[j.ID] = true
    }
    require.Equal(t, jobIDs, gotJobIDs)
    require.Len(t, fs.overlaps, 0)
    gotWorkerCalls := make(map[uuid.UUID]bool)
    for _, params := range fs.params {
        gotWorkerCalls[params.WorkerID.UUID] = true
    }
    require.Equal(t, workerIDs, gotWorkerCalls)
}

// TestAcquirer_WaitsOnNoJobs tests that after a call that returns no jobs, Acquirer waits for a new
// job posting before retrying
func TestAcquirer_WaitsOnNoJobs(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    workerID := uuid.New()
    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    acquiree := newTestAcquiree(t, workerID, pt, tags)
    jobID := uuid.New()
    err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
    require.NoError(t, err)
    err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
    require.NoError(t, err)
    acquiree.startAcquire(ctx, uut)
    require.Eventually(t, func() bool {
        fs.mu.Lock()
        defer fs.mu.Unlock()
        return len(fs.params) == 1
    }, testutil.WaitShort, testutil.IntervalFast)
    acquiree.requireBlocked()

    // First send in some with incompatible tags & types
    postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{
        "cool":   "tapes",
        "strong": "bad",
    })
    postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{
        "foo": "fighters",
    })
    postJob(t, ps, database.ProvisionerTypeTerraform, provisionerdserver.Tags{
        "foo": "bar",
    })
    acquiree.requireBlocked()

    // compatible tags
    postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{})
    job := acquiree.success(ctx)
    require.Equal(t, jobID, job.ID)
}

// TestAcquirer_RetriesPending tests that if we get a job posting while a db call is in progress
// we retry to acquire a job immediately, even if the first call returned no jobs. We want this
// behavior since the query that found no jobs could have resolved before the job was posted, but
// the query result could reach us later than the posting over the pubsub.
func TestAcquirer_RetriesPending(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    workerID := uuid.New()
    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    acquiree := newTestAcquiree(t, workerID, pt, tags)
    jobID := uuid.New()

    acquiree.startAcquire(ctx, uut)
    require.Eventually(t, func() bool {
        fs.mu.Lock()
        defer fs.mu.Unlock()
        return len(fs.params) == 1
    }, testutil.WaitShort, testutil.IntervalFast)

    // First call to DB is in progress. Send in posting
    postJob(t, ps, database.ProvisionerTypeEcho, provisionerdserver.Tags{})
    // there is a race between the posting being processed and the DB call
    // returning. In either case we should retry, but we're trying to hit the
    // case where the posting is processed first, so sleep a little bit to give
    // it a chance.
    time.Sleep(testutil.IntervalMedium)

    // Now, when first DB call returns ErrNoRows we retry.
    err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
    require.NoError(t, err)
    err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
    require.NoError(t, err)

    job := acquiree.success(ctx)
    require.Equal(t, jobID, job.ID)
}

// TestAcquirer_DifferentDomains tests that acquirees with different tags don't block each other
func TestAcquirer_DifferentDomains(t *testing.T) {
    t.Parallel()
    fs := newFakeTaggedStore(t)
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    worker0 := uuid.New()
    tags0 := provisionerdserver.Tags{
        "worker": "0",
    }
    acquiree0 := newTestAcquiree(t, worker0, pt, tags0)
    worker1 := uuid.New()
    tags1 := provisionerdserver.Tags{
        "worker": "1",
    }
    acquiree1 := newTestAcquiree(t, worker1, pt, tags1)
    jobID := uuid.New()
    fs.jobs = []database.ProvisionerJob{
        {ID: jobID, Provisioner: database.ProvisionerTypeEcho, Tags: database.StringMap{"worker": "1"}},
    }

    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    ctx0, cancel0 := context.WithCancel(ctx)
    defer cancel0()
    acquiree0.startAcquire(ctx0, uut)
    select {
    case params := <-fs.params:
        require.Equal(t, worker0, params.WorkerID.UUID)
    case <-ctx.Done():
        t.Fatal("timed out waiting for call to database from worker0")
    }
    acquiree0.requireBlocked()

    // worker1 should not be blocked by worker0, as they are different tags
    acquiree1.startAcquire(ctx, uut)
    job := acquiree1.success(ctx)
    require.Equal(t, jobID, job.ID)

    cancel0()
    acquiree0.requireCanceled(ctx)
}

func TestAcquirer_BackupPoll(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
    uut := provisionerdserver.NewAcquirer(
        ctx, logger.Named("acquirer"), fs, ps,
        provisionerdserver.TestingBackupPollDuration(testutil.IntervalMedium),
    )

    workerID := uuid.New()
    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    acquiree := newTestAcquiree(t, workerID, pt, tags)
    jobID := uuid.New()
    err := fs.sendCtx(ctx, database.ProvisionerJob{}, sql.ErrNoRows)
    require.NoError(t, err)
    err = fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
    require.NoError(t, err)
    acquiree.startAcquire(ctx, uut)
    job := acquiree.success(ctx)
    require.Equal(t, jobID, job.ID)
}

// TestAcquirer_UnblockOnCancel tests that a canceled call doesn't block a call
// from the same domain.
func TestAcquirer_UnblockOnCancel(t *testing.T) {
    t.Parallel()
    fs := newFakeOrderedStore()
    ps := pubsub.NewInMemory()
    ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
    defer cancel()
    logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

    pt := []database.ProvisionerType{database.ProvisionerTypeEcho}
    worker0 := uuid.New()
    tags := provisionerdserver.Tags{
        "foo": "bar",
    }
    acquiree0 := newTestAcquiree(t, worker0, pt, tags)
    worker1 := uuid.New()
    acquiree1 := newTestAcquiree(t, worker1, pt, tags)
    jobID := uuid.New()

    uut := provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), fs, ps)

    // queue up 2 responses --- we may not need both, since acquiree0 will
    // usually cancel before calling, but cancel is async, so it might call.
    for i := 0; i < 2; i++ {
        err := fs.sendCtx(ctx, database.ProvisionerJob{ID: jobID}, nil)
        require.NoError(t, err)
    }

    ctx0, cancel0 := context.WithCancel(ctx)
    cancel0()
    acquiree0.startAcquire(ctx0, uut)
    acquiree1.startAcquire(ctx, uut)
    job := acquiree1.success(ctx)
    require.Equal(t, jobID, job.ID)
}

func postJob(t *testing.T, ps pubsub.Pubsub, pt database.ProvisionerType, tags provisionerdserver.Tags) {
    t.Helper()
    msg, err := json.Marshal(provisionerdserver.JobPosting{
        ProvisionerType: pt,
        Tags:            tags,
    })
    require.NoError(t, err)
    err = ps.Publish(provisionerdserver.EventJobPosted, msg)
    require.NoError(t, err)
}

// fakeOrderedStore is a fake store that lets tests send AcquireProvisionerJob
// results in order over a channel, and tests for overlapped calls.
type fakeOrderedStore struct {
    jobs   chan database.ProvisionerJob
    errors chan error

    mu     sync.Mutex
    params []database.AcquireProvisionerJobParams

    // inflight and overlaps track whether any calls from workers overlap with
    // one another
    inflight map[uuid.UUID]bool
    overlaps [][]uuid.UUID
}

func newFakeOrderedStore() *fakeOrderedStore {
    return &fakeOrderedStore{
        // buffer the channels so that we can queue up lots of responses to
        // occur nearly simultaneously
        jobs:     make(chan database.ProvisionerJob, 100),
        errors:   make(chan error, 100),
        inflight: make(map[uuid.UUID]bool),
    }
}

func (s *fakeOrderedStore) AcquireProvisionerJob(
    _ context.Context, params database.AcquireProvisionerJobParams,
) (
    database.ProvisionerJob, error,
) {
    s.mu.Lock()
    s.params = append(s.params, params)
    for workerID := range s.inflight {
        s.overlaps = append(s.overlaps, []uuid.UUID{workerID, params.WorkerID.UUID})
    }
    s.inflight[params.WorkerID.UUID] = true
    s.mu.Unlock()

    job := <-s.jobs
    err := <-s.errors

    s.mu.Lock()
    delete(s.inflight, params.WorkerID.UUID)
    s.mu.Unlock()

    return job, err
}

func (s *fakeOrderedStore) sendCtx(ctx context.Context, job database.ProvisionerJob, err error) error {
    select {
    case <-ctx.Done():
        return ctx.Err()
    case s.jobs <- job:
        // OK
    }
    select {
    case <-ctx.Done():
        return ctx.Err()
    case s.errors <- err:
        // OK
    }
    return nil
}

// fakeTaggedStore is a test store that allows tests to specify which jobs are
// available, and returns them to callers with the appropriate provisioner type
// and tags. It doesn't care about the order.
type fakeTaggedStore struct {
    t      *testing.T
    mu     sync.Mutex
    jobs   []database.ProvisionerJob
    params chan database.AcquireProvisionerJobParams
}

func newFakeTaggedStore(t *testing.T) *fakeTaggedStore {
    return &fakeTaggedStore{
        t:      t,
        params: make(chan database.AcquireProvisionerJobParams, 100),
    }
}

func (s *fakeTaggedStore) AcquireProvisionerJob(
    _ context.Context, params database.AcquireProvisionerJobParams,
) (
    database.ProvisionerJob, error,
) {
    defer func() { s.params <- params }()
    var tags provisionerdserver.Tags
    err := json.Unmarshal(params.Tags, &tags)
    if !assert.NoError(s.t, err) {
        return database.ProvisionerJob{}, err
    }
    s.mu.Lock()
    defer s.mu.Unlock()
jobLoop:
    for i, job := range s.jobs {
        if !slices.Contains(params.Types, job.Provisioner) {
            continue
        }
        for k, v := range job.Tags {
            pv, ok := tags[k]
            if !ok {
                continue jobLoop
            }
            if v != pv {
                continue jobLoop
            }
        }
        // found a job!
        s.jobs = append(s.jobs[:i], s.jobs[i+1:]...)
        return job, nil
    }
    return database.ProvisionerJob{}, sql.ErrNoRows
}

// testAcquiree is a helper type that handles asynchronously calling AcquireJob
// and asserting whether or not it returns, blocks, or is canceled.
type testAcquiree struct {
    t        *testing.T
    workerID uuid.UUID
    pt       []database.ProvisionerType
    tags     provisionerdserver.Tags
    ec       chan error
    jc       chan database.ProvisionerJob
}

func newTestAcquiree(t *testing.T, workerID uuid.UUID, pt []database.ProvisionerType, tags provisionerdserver.Tags) *testAcquiree {
    return &testAcquiree{
        t:        t,
        workerID: workerID,
        pt:       pt,
        tags:     tags,
        ec:       make(chan error, 1),
        jc:       make(chan database.ProvisionerJob, 1),
    }
}

func (a *testAcquiree) startAcquire(ctx context.Context, uut *provisionerdserver.Acquirer) {
    go func() {
        j, e := uut.AcquireJob(ctx, a.workerID, a.pt, a.tags)
        a.ec <- e
        a.jc <- j
    }()
}

func (a *testAcquiree) success(ctx context.Context) database.ProvisionerJob {
    select {
    case <-ctx.Done():
        a.t.Fatal("timeout waiting for AcquireJob error")
    case err := <-a.ec:
        require.NoError(a.t, err)
    }
    select {
    case <-ctx.Done():
        a.t.Fatal("timeout waiting for AcquireJob job")
    case job := <-a.jc:
        return job
    }
    // unhittable
    return database.ProvisionerJob{}
}

func (a *testAcquiree) requireBlocked() {
    select {
    case <-a.ec:
        a.t.Fatal("AcquireJob should block")
    default:
        // OK
    }
}

func (a *testAcquiree) requireCanceled(ctx context.Context) {
    select {
    case err := <-a.ec:
        require.ErrorIs(a.t, err, context.Canceled)
    case <-ctx.Done():
        a.t.Fatal("timed out waiting for AcquireJob")
    }
    select {
    case job := <-a.jc:
        require.Equal(a.t, uuid.Nil, job.ID)
    case <-ctx.Done():
        a.t.Fatal("timed out waiting for AcquireJob")
    }
}
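The tests above publish JobPosting messages directly via postJob. For completeness, here is a hedged sketch of what the producing side might look like when a provisioner job is inserted; the function name and wiring are assumptions for illustration, since this commit only adds the subscribing side.

// Sketch of the publisher side (assumed wiring, not part of this diff):
// after inserting a provisioner job, notify any waiting Acquirer domains.
package example

import (
    "encoding/json"

    "golang.org/x/xerrors"

    "github.com/coder/coder/v2/coderd/database"
    "github.com/coder/coder/v2/coderd/database/pubsub"
    "github.com/coder/coder/v2/coderd/provisionerdserver"
)

func notifyJobPosted(ps pubsub.Pubsub, job database.ProvisionerJob) error {
    msg, err := json.Marshal(provisionerdserver.JobPosting{
        ProvisionerType: job.Provisioner,
        Tags:            job.Tags,
    })
    if err != nil {
        return xerrors.Errorf("marshal job posting: %w", err)
    }
    // If this publish is lost (e.g. the publisher crashes first), the 30s
    // backup poll in the Acquirer still picks the job up eventually.
    return ps.Publish(provisionerdserver.EventJobPosted, msg)
}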
@@ -79,6 +79,30 @@ type server struct {
    TimeNowFn func() time.Time
}

// We use the null byte (0x00) in generating a canonical map key for tags, so
// it cannot be used in the tag keys or values.

var ErrorTagsContainNullByte = xerrors.New("tags cannot contain the null byte (0x00)")

type Tags map[string]string

func (t Tags) ToJSON() (json.RawMessage, error) {
    r, err := json.Marshal(t)
    if err != nil {
        return nil, err
    }
    return r, err
}

func (t Tags) Valid() error {
    for k, v := range t {
        if slices.Contains([]byte(k), 0x00) || slices.Contains([]byte(v), 0x00) {
            return ErrorTagsContainNullByte
        }
    }
    return nil
}

func NewServer(
    accessURL *url.URL,
    id uuid.UUID,
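A small sketch of how the Tags helpers introduced above fit together: validate first (domainKey uses 0x00 as its delimiter, so tags must not contain it), then serialize for the acquire query. The tag values and helper name are illustrative assumptions.

// Minimal sketch of using Tags.Valid and Tags.ToJSON together (illustrative only).
package example

import (
    "fmt"

    "github.com/coder/coder/v2/coderd/provisionerdserver"
)

func exampleTags() error {
    tags := provisionerdserver.Tags{"scope": "organization", "owner": ""} // example values
    // Valid rejects tags containing 0x00, because domainKey uses the null
    // byte as its delimiter when building the canonical map key.
    if err := tags.Valid(); err != nil {
        return err
    }
    raw, err := tags.ToJSON()
    if err != nil {
        return err
    }
    fmt.Printf("tags as JSON for AcquireProvisionerJobParams.Tags: %s\n", raw)
    return nil
}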