chore: update testutil chan helpers (#17408)

This commit is contained in:
ケイラ
2025-04-16 09:37:09 -07:00
committed by GitHub
parent 2a76f5028e
commit f670bc31f5
45 changed files with 582 additions and 559 deletions

View File

@ -61,7 +61,7 @@ func TestInMemoryCoordination(t *testing.T) {
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Recv loop should be terminated by the server hanging up after Disconnect
err := testutil.RequireRecvCtx(ctx, t, uut.Wait())
err := testutil.TryReceive(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -118,7 +118,7 @@ func TestTunnelSrcCoordController_Mainline(t *testing.T) {
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
err = testutil.TryReceive(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -147,22 +147,22 @@ func TestTunnelSrcCoordController_AddDestination(t *testing.T) {
// THEN: Controller sends AddTunnel for the destinations
for i := range 2 {
b0 := byte(i + 1)
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
call := testutil.TryReceive(ctx, t, client1.reqs)
require.Equal(t, b0, call.req.GetAddTunnel().GetId()[0])
testutil.RequireSendCtx(ctx, t, call.err, nil)
testutil.RequireSend(ctx, t, call.err, nil)
}
_ = testutil.RequireRecvCtx(ctx, t, addDone)
_ = testutil.TryReceive(ctx, t, addDone)
// THEN: Controller sets destinations on Coordinatee
require.Contains(t, fConn.tunnelDestinations, dest1)
require.Contains(t, fConn.tunnelDestinations, dest2)
// WHEN: Closed from server side and reconnects
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
respCall := testutil.TryReceive(ctx, t, client1.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
closeCall := testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
client2 := newFakeCoordinatorClient(ctx, t)
cws := make(chan tailnet.CloserWaiter)
@ -173,21 +173,21 @@ func TestTunnelSrcCoordController_AddDestination(t *testing.T) {
// THEN: should immediately send both destinations
var dests []byte
for range 2 {
call := testutil.RequireRecvCtx(ctx, t, client2.reqs)
call := testutil.TryReceive(ctx, t, client2.reqs)
dests = append(dests, call.req.GetAddTunnel().GetId()[0])
testutil.RequireSendCtx(ctx, t, call.err, nil)
testutil.RequireSend(ctx, t, call.err, nil)
}
slices.Sort(dests)
require.Equal(t, dests, []byte{1, 2})
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
cw2 := testutil.TryReceive(ctx, t, cws)
// close client2
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
respCall = testutil.TryReceive(ctx, t, client2.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
closeCall = testutil.TryReceive(ctx, t, client2.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err = testutil.TryReceive(ctx, t, cw2.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -209,9 +209,9 @@ func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) {
go func() {
cws <- uut.New(client1)
}()
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
call := testutil.TryReceive(ctx, t, client1.reqs)
testutil.RequireSend(ctx, t, call.err, nil)
cw1 := testutil.TryReceive(ctx, t, cws)
// WHEN: we remove one destination
removeDone := make(chan struct{})
@ -221,17 +221,17 @@ func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) {
}()
// THEN: Controller sends RemoveTunnel for the destination
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
call = testutil.TryReceive(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
testutil.RequireSend(ctx, t, call.err, nil)
_ = testutil.TryReceive(ctx, t, removeDone)
// WHEN: Closed from server side and reconnect
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
respCall := testutil.TryReceive(ctx, t, client1.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
closeCall := testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
client2 := newFakeCoordinatorClient(ctx, t)
@ -240,14 +240,14 @@ func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) {
}()
// THEN: should immediately resolve without sending anything
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
cw2 := testutil.TryReceive(ctx, t, cws)
// close client2
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
respCall = testutil.TryReceive(ctx, t, client2.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
closeCall = testutil.TryReceive(ctx, t, client2.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err = testutil.TryReceive(ctx, t, cw2.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -274,10 +274,10 @@ func TestTunnelSrcCoordController_RemoveDestination_Error(t *testing.T) {
cws <- uut.New(client1)
}()
for range 3 {
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
call := testutil.TryReceive(ctx, t, client1.reqs)
testutil.RequireSend(ctx, t, call.err, nil)
}
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
cw1 := testutil.TryReceive(ctx, t, cws)
// WHEN: we remove all destinations
removeDone := make(chan struct{})
@ -290,22 +290,22 @@ func TestTunnelSrcCoordController_RemoveDestination_Error(t *testing.T) {
// WHEN: first RemoveTunnel call fails
theErr := xerrors.New("a bad thing happened")
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
call := testutil.TryReceive(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, theErr)
testutil.RequireSend(ctx, t, call.err, theErr)
// THEN: we disconnect and do not send remaining RemoveTunnel messages
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
closeCall := testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
_ = testutil.TryReceive(ctx, t, removeDone)
// shut down
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
respCall := testutil.TryReceive(ctx, t, client1.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
// triggers second close call
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
closeCall = testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.ErrorIs(t, err, theErr)
}
@ -331,10 +331,10 @@ func TestTunnelSrcCoordController_Sync(t *testing.T) {
cws <- uut.New(client1)
}()
for range 2 {
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
call := testutil.TryReceive(ctx, t, client1.reqs)
testutil.RequireSend(ctx, t, call.err, nil)
}
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
cw1 := testutil.TryReceive(ctx, t, cws)
// WHEN: we sync dest2 & dest3
syncDone := make(chan struct{})
@ -344,23 +344,23 @@ func TestTunnelSrcCoordController_Sync(t *testing.T) {
}()
// THEN: we get an add for dest3 and remove for dest1
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
call := testutil.TryReceive(ctx, t, client1.reqs)
require.Equal(t, dest3[:], call.req.GetAddTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSend(ctx, t, call.err, nil)
call = testutil.TryReceive(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
testutil.RequireSend(ctx, t, call.err, nil)
testutil.RequireRecvCtx(ctx, t, syncDone)
testutil.TryReceive(ctx, t, syncDone)
// dest3 should be added to coordinatee
require.Contains(t, fConn.tunnelDestinations, dest3)
// shut down
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
respCall := testutil.TryReceive(ctx, t, client1.resps)
testutil.RequireSend(ctx, t, respCall.err, io.EOF)
closeCall := testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -384,24 +384,24 @@ func TestTunnelSrcCoordController_AddDestination_Error(t *testing.T) {
uut.AddDestination(dest1)
}()
theErr := xerrors.New("a bad thing happened")
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, theErr)
call := testutil.TryReceive(ctx, t, client1.reqs)
testutil.RequireSend(ctx, t, call.err, theErr)
// THEN: Client is closed and exits
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
closeCall := testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
// close the resps, since the client has closed
resp := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, resp.err, net.ErrClosed)
resp := testutil.TryReceive(ctx, t, client1.resps)
testutil.RequireSend(ctx, t, resp.err, net.ErrClosed)
// this triggers a second Close() call on the client
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
closeCall = testutil.TryReceive(ctx, t, client1.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.ErrorIs(t, err, theErr)
_ = testutil.RequireRecvCtx(ctx, t, addDone)
_ = testutil.TryReceive(ctx, t, addDone)
}
func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
@ -457,7 +457,7 @@ func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
testutil.RequireSend(ctx, t, resps, &proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: clientID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
@ -469,19 +469,19 @@ func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
}},
})
rfh := testutil.RequireRecvCtx(ctx, t, reqs)
rfh := testutil.TryReceive(ctx, t, reqs)
require.NotNil(t, rfh.ReadyForHandshake)
require.Len(t, rfh.ReadyForHandshake, 1)
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
go uut.Close(ctx)
dis := testutil.RequireRecvCtx(ctx, t, reqs)
dis := testutil.TryReceive(ctx, t, reqs)
require.NotNil(t, dis)
require.NotNil(t, dis.Disconnect)
close(resps)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
err = testutil.TryReceive(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
@ -493,14 +493,14 @@ func coordinationTest(
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
req := testutil.TryReceive(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})
req = testutil.RequireRecvCtx(ctx, t, reqs)
req = testutil.TryReceive(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
// When we send a peer update, it should update the coordinatee
@ -519,7 +519,7 @@ func coordinationTest(
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
testutil.RequireSend(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
@ -534,11 +534,11 @@ func coordinationTest(
}()
// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
req = testutil.TryReceive(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
err = testutil.RequireRecvCtx(ctx, t, errCh)
err = testutil.TryReceive(ctx, t, errCh)
require.NoError(t, err)
// It should set all peers lost on the coordinatee
@ -593,12 +593,12 @@ func TestNewBasicDERPController_Mainline(t *testing.T) {
c := uut.New(fc)
ctx := testutil.Context(t, testutil.WaitShort)
expectDM := &tailcfg.DERPMap{}
testutil.RequireSendCtx(ctx, t, fc.ch, expectDM)
gotDM := testutil.RequireRecvCtx(ctx, t, fs)
testutil.RequireSend(ctx, t, fc.ch, expectDM)
gotDM := testutil.TryReceive(ctx, t, fs)
require.Equal(t, expectDM, gotDM)
err := c.Close(ctx)
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, c.Wait())
err = testutil.TryReceive(ctx, t, c.Wait())
require.ErrorIs(t, err, io.EOF)
// ensure Close is idempotent
err = c.Close(ctx)
@ -617,7 +617,7 @@ func TestNewBasicDERPController_RecvErr(t *testing.T) {
}
c := uut.New(fc)
ctx := testutil.Context(t, testutil.WaitShort)
err := testutil.RequireRecvCtx(ctx, t, c.Wait())
err := testutil.TryReceive(ctx, t, c.Wait())
require.ErrorIs(t, err, expectedErr)
// ensure Close is idempotent
err = c.Close(ctx)
@ -668,12 +668,12 @@ func TestBasicTelemetryController_Success(t *testing.T) {
})
}()
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
call := testutil.TryReceive(ctx, t, ft.calls)
require.Len(t, call.req.GetEvents(), 1)
require.Equal(t, call.req.GetEvents()[0].GetId(), []byte("test event"))
testutil.RequireSendCtx(ctx, t, call.errCh, nil)
testutil.RequireRecvCtx(ctx, t, sendDone)
testutil.RequireSend(ctx, t, call.errCh, nil)
testutil.TryReceive(ctx, t, sendDone)
}
func TestBasicTelemetryController_Unimplemented(t *testing.T) {
@ -695,9 +695,9 @@ func TestBasicTelemetryController_Unimplemented(t *testing.T) {
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError)
testutil.RequireRecvCtx(ctx, t, sendDone)
call := testutil.TryReceive(ctx, t, ft.calls)
testutil.RequireSend(ctx, t, call.errCh, telemetryError)
testutil.TryReceive(ctx, t, sendDone)
sendDone = make(chan struct{})
go func() {
@ -706,12 +706,12 @@ func TestBasicTelemetryController_Unimplemented(t *testing.T) {
}()
// we get another call since it wasn't really the Unimplemented error
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
call = testutil.TryReceive(ctx, t, ft.calls)
// for real this time
telemetryError = errUnimplemented
testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError)
testutil.RequireRecvCtx(ctx, t, sendDone)
testutil.RequireSend(ctx, t, call.errCh, telemetryError)
testutil.TryReceive(ctx, t, sendDone)
// now this returns immediately without a call, because unimplemented error disables calling
sendDone = make(chan struct{})
@ -719,7 +719,7 @@ func TestBasicTelemetryController_Unimplemented(t *testing.T) {
defer close(sendDone)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
testutil.RequireRecvCtx(ctx, t, sendDone)
testutil.TryReceive(ctx, t, sendDone)
// getting a "new" client resets
uut.New(ft)
@ -728,9 +728,9 @@ func TestBasicTelemetryController_Unimplemented(t *testing.T) {
defer close(sendDone)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
testutil.RequireSendCtx(ctx, t, call.errCh, nil)
testutil.RequireRecvCtx(ctx, t, sendDone)
call = testutil.TryReceive(ctx, t, ft.calls)
testutil.RequireSend(ctx, t, call.errCh, nil)
testutil.TryReceive(ctx, t, sendDone)
}
func TestBasicTelemetryController_NotRecognised(t *testing.T) {
@ -747,20 +747,20 @@ func TestBasicTelemetryController_NotRecognised(t *testing.T) {
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
// returning generic protocol error doesn't trigger unknown rpc logic
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
testutil.RequireSendCtx(ctx, t, call.errCh, drpc.ProtocolError.New("Protocol Error"))
testutil.RequireRecvCtx(ctx, t, sendDone)
call := testutil.TryReceive(ctx, t, ft.calls)
testutil.RequireSend(ctx, t, call.errCh, drpc.ProtocolError.New("Protocol Error"))
testutil.TryReceive(ctx, t, sendDone)
sendDone = make(chan struct{})
go func() {
defer close(sendDone)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
call = testutil.TryReceive(ctx, t, ft.calls)
// return the expected protocol error this time
testutil.RequireSendCtx(ctx, t, call.errCh,
testutil.RequireSend(ctx, t, call.errCh,
drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry"))
testutil.RequireRecvCtx(ctx, t, sendDone)
testutil.TryReceive(ctx, t, sendDone)
// now this returns immediately without a call, because unimplemented error disables calling
sendDone = make(chan struct{})
@ -768,7 +768,7 @@ func TestBasicTelemetryController_NotRecognised(t *testing.T) {
defer close(sendDone)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
}()
testutil.RequireRecvCtx(ctx, t, sendDone)
testutil.TryReceive(ctx, t, sendDone)
}
type fakeTelemetryClient struct {
@ -822,8 +822,8 @@ func TestBasicResumeTokenController_Mainline(t *testing.T) {
go func() {
cwCh <- uut.New(fr)
}()
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
call := testutil.TryReceive(ctx, t, fr.calls)
testutil.RequireSend(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
Token: "test token 1",
RefreshIn: durationpb.New(100 * time.Second),
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
@ -832,11 +832,11 @@ func TestBasicResumeTokenController_Mainline(t *testing.T) {
token, ok := uut.Token()
require.True(t, ok)
require.Equal(t, "test token 1", token)
cw := testutil.RequireRecvCtx(ctx, t, cwCh)
cw := testutil.TryReceive(ctx, t, cwCh)
w := mClock.Advance(100 * time.Second)
call = testutil.RequireRecvCtx(ctx, t, fr.calls)
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
call = testutil.TryReceive(ctx, t, fr.calls)
testutil.RequireSend(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
Token: "test token 2",
RefreshIn: durationpb.New(50 * time.Second),
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
@ -851,7 +851,7 @@ func TestBasicResumeTokenController_Mainline(t *testing.T) {
err := cw.Close(ctx)
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
err = testutil.TryReceive(ctx, t, cw.Wait())
require.NoError(t, err)
token, ok = uut.Token()
@ -880,24 +880,24 @@ func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
go func() {
cwCh1 <- uut.New(fr1)
}()
call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)
call1 := testutil.TryReceive(ctx, t, fr1.calls)
fr2 := newFakeResumeTokenClient(ctx)
cwCh2 := make(chan tailnet.CloserWaiter, 1)
go func() {
cwCh2 <- uut.New(fr2)
}()
call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)
call2 := testutil.TryReceive(ctx, t, fr2.calls)
testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
testutil.RequireSend(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
Token: "test token 2.0",
RefreshIn: durationpb.New(102 * time.Second),
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
})
cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // this ensures Close was called on 1
cw2 := testutil.TryReceive(ctx, t, cwCh2) // this ensures Close was called on 1
testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
testutil.RequireSend(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
Token: "test token 1",
RefreshIn: durationpb.New(101 * time.Second),
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
@ -910,13 +910,13 @@ func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
require.Equal(t, "test token 2.0", token)
// refresher 1 should already be closed.
cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
cw1 := testutil.TryReceive(ctx, t, cwCh1)
err := testutil.TryReceive(ctx, t, cw1.Wait())
require.NoError(t, err)
w := mClock.Advance(102 * time.Second)
call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
call := testutil.TryReceive(ctx, t, fr2.calls)
testutil.RequireSend(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
Token: "test token 2.1",
RefreshIn: durationpb.New(50 * time.Second),
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
@ -931,7 +931,7 @@ func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
err = cw2.Close(ctx)
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
err = testutil.TryReceive(ctx, t, cw2.Wait())
require.NoError(t, err)
}
@ -948,9 +948,9 @@ func TestBasicResumeTokenController_Unimplemented(t *testing.T) {
fr := newFakeResumeTokenClient(ctx)
cw := uut.New(fr)
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
testutil.RequireSendCtx(ctx, t, call.errCh, errUnimplemented)
err := testutil.RequireRecvCtx(ctx, t, cw.Wait())
call := testutil.TryReceive(ctx, t, fr.calls)
testutil.RequireSend(ctx, t, call.errCh, errUnimplemented)
err := testutil.TryReceive(ctx, t, cw.Wait())
require.NoError(t, err)
_, ok = uut.Token()
require.False(t, ok)
@ -1044,35 +1044,35 @@ func TestController_Disconnects(t *testing.T) {
uut.DERPCtrl = tailnet.NewBasicDERPController(logger.Named("derp_ctrl"), fConn)
uut.Run(ctx)
call := testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
call := testutil.TryReceive(testCtx, t, fCoord.CoordinateCalls)
// simulate a problem with DERPMaps by sending nil
testutil.RequireSendCtx(testCtx, t, derpMapCh, nil)
testutil.RequireSend(testCtx, t, derpMapCh, nil)
// this should cause the coordinate call to hang up WITHOUT disconnecting
reqNil := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
reqNil := testutil.TryReceive(testCtx, t, call.Reqs)
require.Nil(t, reqNil)
// and mark all peers lost
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
_ = testutil.TryReceive(testCtx, t, peersLost)
// ...and then reconnect
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
call = testutil.TryReceive(testCtx, t, fCoord.CoordinateCalls)
// close the coordination call, which should cause a 2nd reconnection
close(call.Resps)
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
_ = testutil.TryReceive(testCtx, t, peersLost)
call = testutil.TryReceive(testCtx, t, fCoord.CoordinateCalls)
// canceling the context should trigger the disconnect message
cancel()
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
reqDisc := testutil.TryReceive(testCtx, t, call.Reqs)
require.NotNil(t, reqDisc)
require.NotNil(t, reqDisc.Disconnect)
close(call.Resps)
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
_ = testutil.RequireRecvCtx(testCtx, t, uut.Closed())
_ = testutil.TryReceive(testCtx, t, peersLost)
_ = testutil.TryReceive(testCtx, t, uut.Closed())
}
func TestController_TelemetrySuccess(t *testing.T) {
@ -1124,14 +1124,14 @@ func TestController_TelemetrySuccess(t *testing.T) {
uut.Run(ctx)
// Coordinate calls happen _after_ telemetry is connected up, so we use this
// to ensure telemetry is connected before sending our event
cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
cc := testutil.TryReceive(ctx, t, fCoord.CoordinateCalls)
defer close(cc.Resps)
tel.SendTelemetryEvent(&proto.TelemetryEvent{
Id: []byte("test event"),
})
testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)
testEvents := testutil.TryReceive(ctx, t, eventCh)
require.Len(t, testEvents, 1)
require.Equal(t, []byte("test event"), testEvents[0].Id)
@ -1157,27 +1157,27 @@ func TestController_WorkspaceUpdates(t *testing.T) {
uut.Run(ctx)
// it should dial and pass the client to the controller
call := testutil.RequireRecvCtx(testCtx, t, fCtrl.calls)
call := testutil.TryReceive(testCtx, t, fCtrl.calls)
require.Equal(t, fClient, call.client)
fCW := newFakeCloserWaiter()
testutil.RequireSendCtx[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
testutil.RequireSend[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
// if the CloserWaiter exits...
testutil.RequireSendCtx(testCtx, t, fCW.errCh, theError)
testutil.RequireSend(testCtx, t, fCW.errCh, theError)
// it should close, redial and reconnect
cCall := testutil.RequireRecvCtx(testCtx, t, fClient.close)
testutil.RequireSendCtx(testCtx, t, cCall, nil)
cCall := testutil.TryReceive(testCtx, t, fClient.close)
testutil.RequireSend(testCtx, t, cCall, nil)
call = testutil.RequireRecvCtx(testCtx, t, fCtrl.calls)
call = testutil.TryReceive(testCtx, t, fCtrl.calls)
require.Equal(t, fClient, call.client)
fCW = newFakeCloserWaiter()
testutil.RequireSendCtx[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
testutil.RequireSend[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
// canceling the context should close the client
cancel()
cCall = testutil.RequireRecvCtx(testCtx, t, fClient.close)
testutil.RequireSendCtx(testCtx, t, cCall, nil)
cCall = testutil.TryReceive(testCtx, t, fClient.close)
testutil.RequireSend(testCtx, t, cCall, nil)
}
type fakeTailnetConn struct {
@ -1492,12 +1492,12 @@ func setupConnectedAllWorkspaceUpdatesController(
coordCW := tsc.New(coordC)
t.Cleanup(func() {
// hang up coord client
coordRecv := testutil.RequireRecvCtx(ctx, t, coordC.resps)
testutil.RequireSendCtx(ctx, t, coordRecv.err, io.EOF)
coordRecv := testutil.TryReceive(ctx, t, coordC.resps)
testutil.RequireSend(ctx, t, coordRecv.err, io.EOF)
// sends close on client
cCall := testutil.RequireRecvCtx(ctx, t, coordC.close)
testutil.RequireSendCtx(ctx, t, cCall, nil)
err := testutil.RequireRecvCtx(ctx, t, coordCW.Wait())
cCall := testutil.TryReceive(ctx, t, coordC.close)
testutil.RequireSend(ctx, t, cCall, nil)
err := testutil.TryReceive(ctx, t, coordCW.Wait())
require.ErrorIs(t, err, io.EOF)
})
@ -1506,9 +1506,9 @@ func setupConnectedAllWorkspaceUpdatesController(
updateCW := uut.New(updateC)
t.Cleanup(func() {
// hang up WorkspaceUpdates client
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.err, io.EOF)
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
upRecvCall := testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, upRecvCall.err, io.EOF)
err := testutil.TryReceive(ctx, t, updateCW.Wait())
require.ErrorIs(t, err, io.EOF)
})
return coordC, updateC, uut
@ -1544,15 +1544,15 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
},
}
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
upRecvCall := testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, upRecvCall.resp, initUp)
// This should trigger AddTunnel for each agent
var adds []uuid.UUID
for range 3 {
coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
coordCall := testutil.TryReceive(ctx, t, coordC.reqs)
adds = append(adds, uuid.Must(uuid.FromBytes(coordCall.req.GetAddTunnel().GetId())))
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
testutil.RequireSend(ctx, t, coordCall.err, nil)
}
require.Contains(t, adds, w1a1ID)
require.Contains(t, adds, w2a1ID)
@ -1576,9 +1576,9 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
"w1.mctest.": {ws1a1IP},
expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
dnsCall := testutil.TryReceive(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
testutil.RequireSend(ctx, t, dnsCall.err, nil)
currentState := tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{
@ -1614,7 +1614,7 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
}
// And the callback
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
cbUpdate := testutil.TryReceive(ctx, t, fUH.ch)
require.Equal(t, currentState, cbUpdate)
// Current recvState should match
@ -1656,13 +1656,13 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
},
}
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
upRecvCall := testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, upRecvCall.resp, initUp)
// Add for w1a1
coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
coordCall := testutil.TryReceive(ctx, t, coordC.reqs)
require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId())
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
testutil.RequireSend(ctx, t, coordCall.err, nil)
expectedCoderConnectFQDN, err := dnsname.ToFQDN(
fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, tailnet.CoderDNSSuffix))
@ -1675,9 +1675,9 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
"w1.coder.": {ws1a1IP},
expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
dnsCall := testutil.TryReceive(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
testutil.RequireSend(ctx, t, dnsCall.err, nil)
initRecvUp := tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{
@ -1694,7 +1694,7 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
DeletedAgents: []*tailnet.Agent{},
}
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
cbUpdate := testutil.TryReceive(ctx, t, fUH.ch)
require.Equal(t, initRecvUp, cbUpdate)
// Current state should match initial
@ -1711,18 +1711,18 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
{Id: w1a1ID[:], WorkspaceId: w1ID[:]},
},
}
upRecvCall = testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, agentUpdate)
upRecvCall = testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, upRecvCall.resp, agentUpdate)
// Add for w1a2
coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
coordCall = testutil.TryReceive(ctx, t, coordC.reqs)
require.Equal(t, w1a2ID[:], coordCall.req.GetAddTunnel().GetId())
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
testutil.RequireSend(ctx, t, coordCall.err, nil)
// Remove for w1a1
coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
coordCall = testutil.TryReceive(ctx, t, coordC.reqs)
require.Equal(t, w1a1ID[:], coordCall.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
testutil.RequireSend(ctx, t, coordCall.err, nil)
// DNS contains only w1a2
expectedDNS = map[dnsname.FQDN][]netip.Addr{
@ -1731,11 +1731,11 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
"w1.coder.": {ws1a2IP},
expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
}
dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls)
dnsCall = testutil.TryReceive(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
testutil.RequireSend(ctx, t, dnsCall.err, nil)
cbUpdate = testutil.RequireRecvCtx(ctx, t, fUH.ch)
cbUpdate = testutil.TryReceive(ctx, t, fUH.ch)
sndRecvUpdate := tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{},
UpsertedAgents: []*tailnet.Agent{
@ -1804,8 +1804,8 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
},
}
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
upRecvCall := testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, upRecvCall.resp, initUp)
expectedCoderConnectFQDN, err := dnsname.ToFQDN(
fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, tailnet.CoderDNSSuffix))
@ -1818,16 +1818,16 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
"w1.coder.": {ws1a1IP},
expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
}
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
dnsCall := testutil.TryReceive(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts)
testutil.RequireSendCtx(ctx, t, dnsCall.err, dnsError)
testutil.RequireSend(ctx, t, dnsCall.err, dnsError)
// should trigger a close on the client
closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close)
testutil.RequireSendCtx(ctx, t, closeCall, io.EOF)
closeCall := testutil.TryReceive(ctx, t, updateC.close)
testutil.RequireSend(ctx, t, closeCall, io.EOF)
// error should be our initial DNS error
err = testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
err = testutil.TryReceive(ctx, t, updateCW.Wait())
require.ErrorIs(t, err, dnsError)
}
@ -1927,12 +1927,12 @@ func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) {
updateC := newFakeWorkspaceUpdateClient(ctx, t)
updateCW := uut.New(updateC)
recvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, recvCall.resp, tc.update)
closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
recvCall := testutil.TryReceive(ctx, t, updateC.recv)
testutil.RequireSend(ctx, t, recvCall.resp, tc.update)
closeCall := testutil.TryReceive(ctx, t, updateC.close)
testutil.RequireSend(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
err := testutil.TryReceive(ctx, t, updateCW.Wait())
require.ErrorContains(t, err, tc.errorContains)
})
}