fix: Prevent race between provisionerd connect and close (#6206)

* fix: Prevent race between provisionerd connect and close

* test: Add detection for provisioner creation after test completion
This commit is contained in:
Mathias Fredriksson
2023-02-14 18:37:43 +02:00
committed by GitHub
parent cde7ff8a2d
commit 860e2829c5
2 changed files with 126 additions and 31 deletions

View File

@ -177,7 +177,17 @@ func (p *Server) connect(ctx context.Context) {
p.opts.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
// Ensure connection is not left hanging during a race between
// close and dial succeeding.
p.mutex.Lock()
if p.isClosed() {
client.DRPCConn().Close()
p.mutex.Unlock()
break
}
p.clientValue.Store(client)
p.mutex.Unlock()
p.opts.Logger.Debug(context.Background(), "connected")
break
}
@ -390,7 +400,8 @@ func retryable(err error) bool {
// is not retryable() or the context expires.
func (p *Server) clientDoWithRetries(
ctx context.Context, f func(context.Context, proto.DRPCProvisionerDaemonClient) (any, error)) (
any, error) {
any, error,
) {
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); {
client, ok := p.client()
if !ok {

View File

@ -55,14 +55,22 @@ func TestProvisionerd(t *testing.T) {
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{}), nil
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
}, provisionerd.Provisioners{})
require.NoError(t, closer.Close())
})
t.Run("ConnectErrorClose", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
defer close(completeChan)
@ -77,10 +85,14 @@ func TestProvisionerd(t *testing.T) {
// the job provided is empty. This is to show it successfully
// tried to get a job, but none were available.
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
completeChan := make(chan struct{})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
acquireJobAttempt := 0
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if acquireJobAttempt == 1 {
close(completeChan)
@ -97,13 +109,17 @@ func TestProvisionerd(t *testing.T) {
t.Run("CloseCancelsJob", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
completeChan := make(chan struct{})
var completed sync.Once
var closer io.Closer
var closerMutex sync.Mutex
closerMutex.Lock()
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
@ -127,7 +143,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
closerMutex.Lock()
defer closerMutex.Unlock()
@ -144,13 +160,17 @@ func TestProvisionerd(t *testing.T) {
// Ensures tars with "../../../etc/passwd" as the path
// are not allowed to run, and will fail the job.
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
@ -172,7 +192,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{}),
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
@ -180,13 +200,17 @@ func TestProvisionerd(t *testing.T) {
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
@ -210,7 +234,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
<-stream.Context().Done()
return nil
@ -223,6 +247,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("TemplateImport", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
@ -234,7 +262,7 @@ func TestProvisionerd(t *testing.T) {
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
@ -270,7 +298,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(request *sdkproto.Parse_Request, stream sdkproto.DRPCProvisioner_ParseStream) error {
data, err := os.ReadFile(filepath.Join(request.Directory, "test.txt"))
require.NoError(t, err)
@ -332,6 +360,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
@ -355,7 +387,7 @@ func TestProvisionerd(t *testing.T) {
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
@ -394,7 +426,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
@ -417,6 +449,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("WorkspaceBuild", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
@ -426,7 +462,7 @@ func TestProvisionerd(t *testing.T) {
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
@ -458,7 +494,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
@ -488,6 +524,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
@ -498,7 +538,7 @@ func TestProvisionerd(t *testing.T) {
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
@ -539,7 +579,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
@ -579,6 +619,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didFail atomic.Bool
didAcquireJob atomic.Bool
@ -587,7 +631,7 @@ func TestProvisionerd(t *testing.T) {
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if !didAcquireJob.CAS(false, true) {
completeOnce.Do(func() { close(completeChan) })
@ -614,7 +658,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
return stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Complete{
@ -633,12 +677,16 @@ func TestProvisionerd(t *testing.T) {
t.Run("Shutdown", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var updated sync.Once
var completed sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
@ -676,7 +724,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
@ -714,12 +762,16 @@ func TestProvisionerd(t *testing.T) {
t.Run("ShutdownFromJob", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var completed sync.Once
var updated sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
return &proto.AcquiredJob{
JobId: "test",
@ -765,7 +817,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
@ -801,6 +853,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("ReconnectAndFail", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
second atomic.Bool
failChan = make(chan struct{})
@ -811,7 +867,7 @@ func TestProvisionerd(t *testing.T) {
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
return &proto.AcquiredJob{}, nil
@ -854,7 +910,7 @@ func TestProvisionerd(t *testing.T) {
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
@ -874,6 +930,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("ReconnectAndComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
second atomic.Bool
failChan = make(chan struct{})
@ -884,7 +944,7 @@ func TestProvisionerd(t *testing.T) {
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, provisionerDaemonTestServer{
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
if second.Load() {
completeOnce.Do(func() { close(completeChan) })
@ -929,7 +989,7 @@ func TestProvisionerd(t *testing.T) {
}
return client, nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
// Ignore the first provision message!
_, _ = stream.Recv()
@ -947,6 +1007,10 @@ func TestProvisionerd(t *testing.T) {
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
logger := slogtest.Make(t, nil)
m := sync.Mutex{}
var ops []string
@ -954,7 +1018,7 @@ func TestProvisionerd(t *testing.T) {
completeOnce := sync.Once{}
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, provisionerDaemonTestServer{
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJob: func(ctx context.Context, _ *proto.Empty) (*proto.AcquiredJob, error) {
m.Lock()
defer m.Unlock()
@ -1004,7 +1068,7 @@ func TestProvisionerd(t *testing.T) {
},
}), nil
}, provisionerd.Provisioners{
"someprovisioner": createProvisionerClient(t, provisionerTestServer{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
provision: func(stream sdkproto.DRPCProvisioner_ProvisionStream) error {
err := stream.Send(&sdkproto.Provision_Response{
Type: &sdkproto.Provision_Response_Log{
@ -1070,7 +1134,7 @@ func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, provisioners p
// Creates a provisionerd protobuf client that's connected
// to the server implementation provided.
func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
t.Helper()
if server.failJob == nil {
// Default to asserting the error from the failure, otherwise
@ -1098,13 +1162,23 @@ func createProvisionerDaemonClient(t *testing.T, server provisionerDaemonTestSer
t.Cleanup(func() {
cancelFunc()
<-closed
select {
case <-done:
t.Error("createProvisionerDaemonClient cleanup after test was done!")
default:
}
})
select {
case <-done:
t.Error("called createProvisionerDaemonClient after test was done!")
default:
}
return proto.NewDRPCProvisionerDaemonClient(clientPipe)
}
// Creates a provisioner protobuf client that's connected
// to the server implementation provided.
func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
t.Helper()
clientPipe, serverPipe := provisionersdk.MemTransportPipe()
t.Cleanup(func() {
@ -1124,7 +1198,17 @@ func createProvisionerClient(t *testing.T, server provisionerTestServer) sdkprot
t.Cleanup(func() {
cancelFunc()
<-closed
select {
case <-done:
t.Error("createProvisionerClient cleanup after test was done!")
default:
}
})
select {
case <-done:
t.Error("called createProvisionerClient after test was done!")
default:
}
return sdkproto.NewDRPCProvisionerClient(clientPipe)
}