mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
fix: Prevent race between provisionerd connect and close (#6206)
* fix: Prevent race between provisionerd connect and close * test: Add detection for provisioner creation after test completion
This commit is contained in:
committed by
GitHub
parent
cde7ff8a2d
commit
860e2829c5
@ -177,7 +177,17 @@ func (p *Server) connect(ctx context.Context) {
|
|||||||
p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Ensure connection is not left hanging during a race between
|
||||||
|
// close and dial succeeding.
|
||||||
|
p.mutex.Lock()
|
||||||
|
if p.isClosed() {
|
||||||
|
client.DRPCConn().Close()
|
||||||
|
p.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
p.clientValue.Store(client)
|
p.clientValue.Store(client)
|
||||||
|
p.mutex.Unlock()
|
||||||
|
|
||||||
p.opts.Logger.Debug(context.Background(), "connected")
|
p.opts.Logger.Debug(context.Background(), "connected")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -390,7 +400,8 @@ func retryable(err error) bool {
|
|||||||
// is not retryable() or the context expires.
|
// is not retryable() or the context expires.
|
||||||
func (p *Server) clientDoWithRetries(
|
func (p *Server) clientDoWithRetries(
|
||||||
ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) (
|
ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) (
|
||||||
any, error) {
|
any, error,
|
||||||
|
) {
|
||||||
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
|
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
|
||||||
client, ok := p.client()
|
client, ok := p.client()
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -55,14 +55,22 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("InstantClose", func(t *testing.T) {
|
t.Run("InstantClose", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
|
||||||
}, provisionerd.Provisioners{})
|
}, provisionerd.Provisioners{})
|
||||||
require.NoError(t, closer.Close())
|
require.NoError(t, closer.Close())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ConnectErrorClose", func(t *testing.T) {
|
t.Run("ConnectErrorClose", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
completeChan := make(chan struct{})
|
completeChan := make(chan struct{})
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
defer close(completeChan)
|
defer close(completeChan)
|
||||||
@ -77,10 +85,14 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
// the job provided is empty. This is to show it successfully
|
// the job provided is empty. This is to show it successfully
|
||||||
// tried to get a job, but none were available.
|
// tried to get a job, but none were available.
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
completeChan := make(chan struct{})
|
completeChan := make(chan struct{})
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
acquireJobAttempt := 0
|
acquireJobAttempt := 0
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if acquireJobAttempt == 1 {
|
if acquireJobAttempt == 1 {
|
||||||
close(completeChan)
|
close(completeChan)
|
||||||
@ -97,13 +109,17 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("CloseCancelsJob", func(t *testing.T) {
|
t.Run("CloseCancelsJob", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
completeChan := make(chan struct{})
|
completeChan := make(chan struct{})
|
||||||
var completed sync.Once
|
var completed sync.Once
|
||||||
var closer io.Closer
|
var closer io.Closer
|
||||||
var closerMutex sync.Mutex
|
var closerMutex sync.Mutex
|
||||||
closerMutex.Lock()
|
closerMutex.Lock()
|
||||||
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
return &proto.AcquiredJob{
|
return &proto.AcquiredJob{
|
||||||
JobId: "test",
|
JobId: "test",
|
||||||
@ -127,7 +143,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||||
closerMutex.Lock()
|
closerMutex.Lock()
|
||||||
defer closerMutex.Unlock()
|
defer closerMutex.Unlock()
|
||||||
@ -144,13 +160,17 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
// Ensures tars with "../../../etc/passwd" as the path
|
// Ensures tars with "../../../etc/passwd" as the path
|
||||||
// are not allowed to run, and will fail the job.
|
// are not allowed to run, and will fail the job.
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
completeChan = make(chan struct{})
|
completeChan = make(chan struct{})
|
||||||
completeOnce sync.Once
|
completeOnce sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
return &proto.AcquiredJob{
|
return &proto.AcquiredJob{
|
||||||
JobId: "test",
|
JobId: "test",
|
||||||
@ -172,7 +192,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{}),
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
|
||||||
})
|
})
|
||||||
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
|
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
|
||||||
require.NoError(t, closer.Close())
|
require.NoError(t, closer.Close())
|
||||||
@ -180,13 +200,17 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
|
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
completeChan = make(chan struct{})
|
completeChan = make(chan struct{})
|
||||||
completeOnce sync.Once
|
completeOnce sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
return &proto.AcquiredJob{
|
return &proto.AcquiredJob{
|
||||||
JobId: "test",
|
JobId: "test",
|
||||||
@ -210,7 +234,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||||
<-stream.Context().Done()
|
<-stream.Context().Done()
|
||||||
return nil
|
return nil
|
||||||
@ -223,6 +247,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("TemplateImport", func(t *testing.T) {
|
t.Run("TemplateImport", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
didComplete atomic.Bool
|
didComplete atomic.Bool
|
||||||
didLog atomic.Bool
|
didLog atomic.Bool
|
||||||
@ -234,7 +262,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if !didAcquireJob.CAS(false, true) {
|
if !didAcquireJob.CAS(false, true) {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -270,7 +298,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||||
data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt"))
|
data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -332,6 +360,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("TemplateDryRun", func(t *testing.T) {
|
t.Run("TemplateDryRun", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
didComplete atomic.Bool
|
didComplete atomic.Bool
|
||||||
didLog atomic.Bool
|
didLog atomic.Bool
|
||||||
@ -355,7 +387,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if !didAcquireJob.CAS(false, true) {
|
if !didAcquireJob.CAS(false, true) {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -394,7 +426,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
err := stream.Send(&sdkproto.Provision_Response{
|
err := stream.Send(&sdkproto.Provision_Response{
|
||||||
Type: &sdkproto.Provision_Response_Complete{
|
Type: &sdkproto.Provision_Response_Complete{
|
||||||
@ -417,6 +449,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("WorkspaceBuild", func(t *testing.T) {
|
t.Run("WorkspaceBuild", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
didComplete atomic.Bool
|
didComplete atomic.Bool
|
||||||
didLog atomic.Bool
|
didLog atomic.Bool
|
||||||
@ -426,7 +462,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if !didAcquireJob.CAS(false, true) {
|
if !didAcquireJob.CAS(false, true) {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -458,7 +494,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
err := stream.Send(&sdkproto.Provision_Response{
|
err := stream.Send(&sdkproto.Provision_Response{
|
||||||
Type: &sdkproto.Provision_Response_Log{
|
Type: &sdkproto.Provision_Response_Log{
|
||||||
@ -488,6 +524,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
|
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
didComplete atomic.Bool
|
didComplete atomic.Bool
|
||||||
didLog atomic.Bool
|
didLog atomic.Bool
|
||||||
@ -498,7 +538,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if !didAcquireJob.CAS(false, true) {
|
if !didAcquireJob.CAS(false, true) {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -539,7 +579,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
err := stream.Send(&sdkproto.Provision_Response{
|
err := stream.Send(&sdkproto.Provision_Response{
|
||||||
Type: &sdkproto.Provision_Response_Log{
|
Type: &sdkproto.Provision_Response_Log{
|
||||||
@ -579,6 +619,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
|
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
didFail atomic.Bool
|
didFail atomic.Bool
|
||||||
didAcquireJob atomic.Bool
|
didAcquireJob atomic.Bool
|
||||||
@ -587,7 +631,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if !didAcquireJob.CAS(false, true) {
|
if !didAcquireJob.CAS(false, true) {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -614,7 +658,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
return stream.Send(&sdkproto.Provision_Response{
|
return stream.Send(&sdkproto.Provision_Response{
|
||||||
Type: &sdkproto.Provision_Response_Complete{
|
Type: &sdkproto.Provision_Response_Complete{
|
||||||
@ -633,12 +677,16 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Shutdown", func(t *testing.T) {
|
t.Run("Shutdown", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var updated sync.Once
|
var updated sync.Once
|
||||||
var completed sync.Once
|
var completed sync.Once
|
||||||
updateChan := make(chan struct{})
|
updateChan := make(chan struct{})
|
||||||
completeChan := make(chan struct{})
|
completeChan := make(chan struct{})
|
||||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
return &proto.AcquiredJob{
|
return &proto.AcquiredJob{
|
||||||
JobId: "test",
|
JobId: "test",
|
||||||
@ -676,7 +724,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
// Ignore the first provision message!
|
// Ignore the first provision message!
|
||||||
_, _ = stream.Recv()
|
_, _ = stream.Recv()
|
||||||
@ -714,12 +762,16 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("ShutdownFromJob", func(t *testing.T) {
|
t.Run("ShutdownFromJob", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var completed sync.Once
|
var completed sync.Once
|
||||||
var updated sync.Once
|
var updated sync.Once
|
||||||
updateChan := make(chan struct{})
|
updateChan := make(chan struct{})
|
||||||
completeChan := make(chan struct{})
|
completeChan := make(chan struct{})
|
||||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
return &proto.AcquiredJob{
|
return &proto.AcquiredJob{
|
||||||
JobId: "test",
|
JobId: "test",
|
||||||
@ -765,7 +817,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
// Ignore the first provision message!
|
// Ignore the first provision message!
|
||||||
_, _ = stream.Recv()
|
_, _ = stream.Recv()
|
||||||
@ -801,6 +853,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("ReconnectAndFail", func(t *testing.T) {
|
t.Run("ReconnectAndFail", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
second atomic.Bool
|
second atomic.Bool
|
||||||
failChan = make(chan struct{})
|
failChan = make(chan struct{})
|
||||||
@ -811,7 +867,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
completeOnce sync.Once
|
completeOnce sync.Once
|
||||||
)
|
)
|
||||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if second.Load() {
|
if second.Load() {
|
||||||
return &proto.AcquiredJob{}, nil
|
return &proto.AcquiredJob{}, nil
|
||||||
@ -854,7 +910,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
// Ignore the first provision message!
|
// Ignore the first provision message!
|
||||||
_, _ = stream.Recv()
|
_, _ = stream.Recv()
|
||||||
@ -874,6 +930,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("ReconnectAndComplete", func(t *testing.T) {
|
t.Run("ReconnectAndComplete", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
var (
|
var (
|
||||||
second atomic.Bool
|
second atomic.Bool
|
||||||
failChan = make(chan struct{})
|
failChan = make(chan struct{})
|
||||||
@ -884,7 +944,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
completeOnce sync.Once
|
completeOnce sync.Once
|
||||||
)
|
)
|
||||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
if second.Load() {
|
if second.Load() {
|
||||||
completeOnce.Do(func() { close(completeChan) })
|
completeOnce.Do(func() { close(completeChan) })
|
||||||
@ -929,7 +989,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
// Ignore the first provision message!
|
// Ignore the first provision message!
|
||||||
_, _ = stream.Recv()
|
_, _ = stream.Recv()
|
||||||
@ -947,6 +1007,10 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
|
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
done := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
logger := slogtest.Make(t, nil)
|
logger := slogtest.Make(t, nil)
|
||||||
m := sync.Mutex{}
|
m := sync.Mutex{}
|
||||||
var ops []string
|
var ops []string
|
||||||
@ -954,7 +1018,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
completeOnce := sync.Once{}
|
completeOnce := sync.Once{}
|
||||||
|
|
||||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
@ -1004,7 +1068,7 @@ func TestProvisionerd(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}), nil
|
}), nil
|
||||||
}, provisionerd.Provisioners{
|
}, provisionerd.Provisioners{
|
||||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||||
err := stream.Send(&sdkproto.Provision_Response{
|
err := stream.Send(&sdkproto.Provision_Response{
|
||||||
Type: &sdkproto.Provision_Response_Log{
|
Type: &sdkproto.Provision_Response_Log{
|
||||||
@ -1070,7 +1134,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners p
|
|||||||
|
|
||||||
// Creates a provisionerd protobuf client that's connected
|
// Creates a provisionerd protobuf client that's connected
|
||||||
// to the server implementation provided.
|
// to the server implementation provided.
|
||||||
func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
|
func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
if server.failJob == nil {
|
if server.failJob == nil {
|
||||||
// Default to asserting the error from the failure, otherwise
|
// Default to asserting the error from the failure, otherwise
|
||||||
@ -1098,13 +1162,23 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer
|
|||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
<-closed
|
<-closed
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Error("createProvisionerDaemonClient cleanup after test was done!")
|
||||||
|
default:
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Error("called createProvisionerDaemonClient after test was done!")
|
||||||
|
default:
|
||||||
|
}
|
||||||
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
|
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a provisioner protobuf client that's connected
|
// Creates a provisioner protobuf client that's connected
|
||||||
// to the server implementation provided.
|
// to the server implementation provided.
|
||||||
func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
|
func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
|
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@ -1124,7 +1198,17 @@ func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkprot
|
|||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
<-closed
|
<-closed
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Error("createProvisionerClient cleanup after test was done!")
|
||||||
|
default:
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
t.Error("called createProvisionerClient after test was done!")
|
||||||
|
default:
|
||||||
|
}
|
||||||
return sdkproto.NewDRPCProvisionerClient(clientPipe)
|
return sdkproto.NewDRPCProvisionerClient(clientPipe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user