mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
fix: Prevent race between provisionerd connect and close (#6206)
* fix: Prevent race between provisionerd connect and close * test: Add detection for provisioner creation after test completion
This commit is contained in:
committed by
GitHub
parent
cde7ff8a2d
commit
860e2829c5
@ -177,7 +177,17 @@ func (p *Server) connect(ctx context.Context) {
|
||||
p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
// Ensure connection is not left hanging during a race between
|
||||
// close and dial succeeding.
|
||||
p.mutex.Lock()
|
||||
if p.isClosed() {
|
||||
client.DRPCConn().Close()
|
||||
p.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
p.clientValue.Store(client)
|
||||
p.mutex.Unlock()
|
||||
|
||||
p.opts.Logger.Debug(context.Background(), "connected")
|
||||
break
|
||||
}
|
||||
@ -390,7 +400,8 @@ func retryable(err error) bool {
|
||||
// is not retryable() or the context expires.
|
||||
func (p *Server) clientDoWithRetries(
|
||||
ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) (
|
||||
any, error) {
|
||||
any, error,
|
||||
) {
|
||||
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
|
||||
client, ok := p.client()
|
||||
if !ok {
|
||||
|
@ -55,14 +55,22 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("InstantClose", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
|
||||
}, provisionerd.Provisioners{})
|
||||
require.NoError(t, closer.Close())
|
||||
})
|
||||
|
||||
t.Run("ConnectErrorClose", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
completeChan := make(chan struct{})
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
defer close(completeChan)
|
||||
@ -77,10 +85,14 @@ func TestProvisionerd(t *testing.T) {
|
||||
// the job provided is empty. This is to show it successfully
|
||||
// tried to get a job, but none were available.
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
completeChan := make(chan struct{})
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
acquireJobAttempt := 0
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if acquireJobAttempt == 1 {
|
||||
close(completeChan)
|
||||
@ -97,13 +109,17 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("CloseCancelsJob", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
completeChan := make(chan struct{})
|
||||
var completed sync.Once
|
||||
var closer io.Closer
|
||||
var closerMutex sync.Mutex
|
||||
closerMutex.Lock()
|
||||
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
@ -127,7 +143,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||
closerMutex.Lock()
|
||||
defer closerMutex.Unlock()
|
||||
@ -144,13 +160,17 @@ func TestProvisionerd(t *testing.T) {
|
||||
// Ensures tars with "../../../etc/passwd" as the path
|
||||
// are not allowed to run, and will fail the job.
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
completeChan = make(chan struct{})
|
||||
completeOnce sync.Once
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
@ -172,7 +192,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{}),
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
|
||||
})
|
||||
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
|
||||
require.NoError(t, closer.Close())
|
||||
@ -180,13 +200,17 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
completeChan = make(chan struct{})
|
||||
completeOnce sync.Once
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
@ -210,7 +234,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||
<-stream.Context().Done()
|
||||
return nil
|
||||
@ -223,6 +247,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("TemplateImport", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
didComplete atomic.Bool
|
||||
didLog atomic.Bool
|
||||
@ -234,7 +262,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if !didAcquireJob.CAS(false, true) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -270,7 +298,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
|
||||
data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt"))
|
||||
require.NoError(t, err)
|
||||
@ -332,6 +360,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("TemplateDryRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
didComplete atomic.Bool
|
||||
didLog atomic.Bool
|
||||
@ -355,7 +387,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if !didAcquireJob.CAS(false, true) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -394,7 +426,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
err := stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Complete{
|
||||
@ -417,6 +449,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("WorkspaceBuild", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
didComplete atomic.Bool
|
||||
didLog atomic.Bool
|
||||
@ -426,7 +462,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if !didAcquireJob.CAS(false, true) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -458,7 +494,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
err := stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Log{
|
||||
@ -488,6 +524,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
didComplete atomic.Bool
|
||||
didLog atomic.Bool
|
||||
@ -498,7 +538,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if !didAcquireJob.CAS(false, true) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -539,7 +579,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
err := stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Log{
|
||||
@ -579,6 +619,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
didFail atomic.Bool
|
||||
didAcquireJob atomic.Bool
|
||||
@ -587,7 +631,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if !didAcquireJob.CAS(false, true) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -614,7 +658,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
return stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Complete{
|
||||
@ -633,12 +677,16 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("Shutdown", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var updated sync.Once
|
||||
var completed sync.Once
|
||||
updateChan := make(chan struct{})
|
||||
completeChan := make(chan struct{})
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
@ -676,7 +724,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
@ -714,12 +762,16 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("ShutdownFromJob", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var completed sync.Once
|
||||
var updated sync.Once
|
||||
updateChan := make(chan struct{})
|
||||
completeChan := make(chan struct{})
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
return &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
@ -765,7 +817,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
@ -801,6 +853,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("ReconnectAndFail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
second atomic.Bool
|
||||
failChan = make(chan struct{})
|
||||
@ -811,7 +867,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
completeOnce sync.Once
|
||||
)
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if second.Load() {
|
||||
return &proto.AcquiredJob{}, nil
|
||||
@ -854,7 +910,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
}
|
||||
return client, nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
@ -874,6 +930,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("ReconnectAndComplete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
second atomic.Bool
|
||||
failChan = make(chan struct{})
|
||||
@ -884,7 +944,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
completeOnce sync.Once
|
||||
)
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
if second.Load() {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
@ -929,7 +989,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
}
|
||||
return client, nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
// Ignore the first provision message!
|
||||
_, _ = stream.Recv()
|
||||
@ -947,6 +1007,10 @@ func TestProvisionerd(t *testing.T) {
|
||||
|
||||
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
logger := slogtest.Make(t, nil)
|
||||
m := sync.Mutex{}
|
||||
var ops []string
|
||||
@ -954,7 +1018,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
completeOnce := sync.Once{}
|
||||
|
||||
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
@ -1004,7 +1068,7 @@ func TestProvisionerd(t *testing.T) {
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.Provisioners{
|
||||
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
|
||||
err := stream.Send(&sdkproto.Provision_Response{
|
||||
Type: &sdkproto.Provision_Response_Log{
|
||||
@ -1070,7 +1134,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners p
|
||||
|
||||
// Creates a provisionerd protobuf client that's connected
|
||||
// to the server implementation provided.
|
||||
func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
|
||||
func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
|
||||
t.Helper()
|
||||
if server.failJob == nil {
|
||||
// Default to asserting the error from the failure, otherwise
|
||||
@ -1098,13 +1162,23 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer
|
||||
t.Cleanup(func() {
|
||||
cancelFunc()
|
||||
<-closed
|
||||
select {
|
||||
case <-done:
|
||||
t.Error("createProvisionerDaemonClient cleanup after test was done!")
|
||||
default:
|
||||
}
|
||||
})
|
||||
select {
|
||||
case <-done:
|
||||
t.Error("called createProvisionerDaemonClient after test was done!")
|
||||
default:
|
||||
}
|
||||
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
|
||||
}
|
||||
|
||||
// Creates a provisioner protobuf client that's connected
|
||||
// to the server implementation provided.
|
||||
func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
|
||||
func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
|
||||
t.Helper()
|
||||
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
|
||||
t.Cleanup(func() {
|
||||
@ -1124,7 +1198,17 @@ func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkprot
|
||||
t.Cleanup(func() {
|
||||
cancelFunc()
|
||||
<-closed
|
||||
select {
|
||||
case <-done:
|
||||
t.Error("createProvisionerClient cleanup after test was done!")
|
||||
default:
|
||||
}
|
||||
})
|
||||
select {
|
||||
case <-done:
|
||||
t.Error("called createProvisionerClient after test was done!")
|
||||
default:
|
||||
}
|
||||
return sdkproto.NewDRPCProvisionerClient(clientPipe)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user