From 2df9a3e55478f29c750caa97f6c5fbe021f3943a Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Mon, 16 Sep 2024 09:24:30 +0400
Subject: [PATCH] fix: fix tailnet remoteCoordination to wait for server (#14666)

Fixes #12560

When gracefully disconnecting from the coordinator, we would send the
Disconnect message and then close the dRPC stream. However, closing the
dRPC stream can cause the server not to process the Disconnect message,
since we use the stream context in a `select` while sending it to the
coordinator.

This is a product bug uncovered by the flake, and it probably causes us
to fail graceful disconnect a small fraction of the time.

Instead, `remoteCoordination` (and `inMemoryCoordination` for
consistency) should send the Disconnect message and then wait for the
coordinator to hang up, bounded by a graceful-disconnect timer in the
form of a context.
---
 agent/agent.go                              |  2 +-
 agent/agent_test.go                         |  8 ++-
 codersdk/workspacesdk/connector.go          |  2 +-
 .../workspacesdk/connector_internal_test.go |  6 +-
 tailnet/coordinator.go                      | 57 +++++++++++--------
 tailnet/coordinator_test.go                 | 42 +++++++-------
 tailnet/test/integration/integration.go    |  4 +-
 7 files changed, 72 insertions(+), 49 deletions(-)

diff --git a/agent/agent.go b/agent/agent.go
index 98e294320b..2194e04dd1 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 	defer close(errCh)
 	select {
 	case <-ctx.Done():
-		err := coordination.Close()
+		err := coordination.Close(a.hardCtx)
 		if err != nil {
 			a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
 		}
diff --git a/agent/agent_test.go b/agent/agent_test.go
index e4aac04e0e..91e7c1c34e 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
 			coordinator, conn)
 		t.Cleanup(func() {
 			t.Logf("closing coordination %s", name)
-			err := coordination.Close()
+			cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+			defer ccancel()
+			err := coordination.Close(cctx)
 			if err != nil {
 				t.Logf("error closing in-memory coordination: %s", err.Error())
 			}
@@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 		clientID, metadata.AgentID,
 		coordinator, conn)
 	t.Cleanup(func() {
-		err := coordination.Close()
+		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+		defer ccancel()
+		err := coordination.Close(cctx)
 		if err != nil {
 			t.Logf("error closing in-mem coordination: %s", err.Error())
 		}
diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go
index c761c92ae3..780478e91a 100644
--- a/codersdk/workspacesdk/connector.go
+++ b/codersdk/workspacesdk/connector.go
@@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
 	select {
 	case <-tac.ctx.Done():
 		tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
-		crdErr := coordination.Close()
+		crdErr := coordination.Close(tac.gracefulCtx)
 		if crdErr != nil {
 			tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
 		}
diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go
index d56f45b482..7a339a0079 100644
--- a/codersdk/workspacesdk/connector_internal_test.go
+++ b/codersdk/workspacesdk/connector_internal_test.go
@@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 	derpMapCh := make(chan *tailcfg.DERPMap)
 	defer close(derpMapCh)
 	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
-		Logger:                 logger,
+		Logger:                 logger.Named("svc"),
 		CoordPtr:               &coordPtr,
 		DERPMapUpdateFrequency: time.Millisecond,
 		DERPMapFn:              func() *tailcfg.DERPMap { return <-derpMapCh },
@@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 
 	fConn := newFakeTailnetConn()
 
-	uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
+	uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
+		quartz.NewReal(), &websocket.DialOptions{})
 	uut.runConnector(fConn)
 
 	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
 	require.NotNil(t, reqDisc)
 	require.NotNil(t, reqDisc.Disconnect)
+	close(call.Resps)
 }
 
 func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index cc50c792f1..54ce868df9 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -91,7 +91,7 @@ type Coordinatee interface {
 }
 
 type Coordination interface {
-	io.Closer
+	Close(context.Context) error
 	Error() <-chan error
 }
 
@@ -106,7 +106,10 @@ type remoteCoordination struct {
 	respLoopDone chan struct{}
 }
 
-func (c *remoteCoordination) Close() (retErr error) {
+// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
+// waiting for the server to hang up the coordination. If the provided context expires, we stop
+// waiting for the server and close the coordination stream from our end.
+func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
 	c.Lock()
 	defer c.Unlock()
 	if c.closed {
@@ -114,6 +117,18 @@ func (c *remoteCoordination) Close() (retErr error) {
 	}
 	c.closed = true
 	defer func() {
+		// We shouldn't just close the protocol right away, because the way dRPC streams work is
+		// that if you close them, that could take effect immediately, even before the Disconnect
+		// message is processed. Coordinators are supposed to hang up on us once they get a
+		// Disconnect message, so we should wait around for that until the context expires.
+		select {
+		case <-c.respLoopDone:
+			c.logger.Debug(ctx, "responses closed after disconnect")
+			return
+		case <-ctx.Done():
+			c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
+		}
+		// forcefully close the stream
 		protoErr := c.protocol.Close()
 		<-c.respLoopDone
 		if retErr == nil {
@@ -240,7 +255,6 @@ type inMemoryCoordination struct {
 	ctx          context.Context
 	errChan      chan error
 	closed       bool
-	closedCh     chan struct{}
 	respLoopDone chan struct{}
 	coordinatee  Coordinatee
 	logger       slog.Logger
@@ -280,7 +294,6 @@ func NewInMemoryCoordination(
 		errChan:      make(chan error, 1),
 		coordinatee:  coordinatee,
 		logger:       logger,
-		closedCh:     make(chan struct{}),
 		respLoopDone: make(chan struct{}),
 	}
 
@@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
 		c.coordinatee.SetAllPeersLost()
 		close(c.respLoopDone)
 	}()
-	for {
-		select {
-		case <-c.closedCh:
-			c.logger.Debug(context.Background(), "in-memory coordination closed")
+	for resp := range c.resps {
+		c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
+		err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
+		if err != nil {
+			c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
 			return
-		case resp, ok := <-c.resps:
-			if !ok {
-				c.logger.Debug(context.Background(), "in-memory response channel closed")
-				return
-			}
-			c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
-			err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
-			if err != nil {
-				c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
-				return
-			}
 		}
 	}
+	c.logger.Debug(context.Background(), "in-memory response channel closed")
 }
 
 func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
@@ -355,7 +359,10 @@
 	return ch
 }
-func (c *inMemoryCoordination) Close() error {
+// Close attempts to gracefully close the inMemoryCoordination by sending a Disconnect message
+// and waiting for the coordinator to hang up the coordination. If the provided context expires,
+// we stop waiting for the coordinator and close the coordination from our end.
+func (c *inMemoryCoordination) Close(ctx context.Context) error {
 	c.Lock()
 	defer c.Unlock()
 	c.logger.Debug(context.Background(), "closing in-memory coordination")
 	if c.closed {
@@ -364,13 +371,17 @@
 	}
 	defer close(c.reqs)
 	c.closed = true
-	close(c.closedCh)
-	<-c.respLoopDone
 	select {
-	case <-c.ctx.Done():
+	case <-ctx.Done():
 		return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
 	case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
 		c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
+	}
+
+	select {
+	case <-ctx.Done():
+		return xerrors.Errorf("context expired waiting for responses to close: %w", ctx.Err())
+	case <-c.respLoopDone:
 		return nil
 	}
 }
diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go
index 400084fafa..99b4724e35 100644
--- a/tailnet/coordinator_test.go
+++ b/tailnet/coordinator_test.go
@@ -2,6 +2,7 @@ package tailnet_test
 
 import (
 	"context"
+	"io"
 	"net"
 	"net/netip"
 	"sync"
@@ -284,7 +285,7 @@
 		Times(1).Return(reqs, resps)
 
 	uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
 
@@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
 	require.NoError(t, err)
 
 	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
 
-	select {
-	case err := <-uut.Error():
-		require.ErrorContains(t, err, "stream terminated by sending close")
-	default:
-		// OK!
-	}
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+	require.ErrorIs(t, err, io.EOF)
 }
 
 func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
@@ -388,7 +386,7 @@
 	require.NoError(t, err)
 
 	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	nk, err := key.NewNode().Public().MarshalBinary()
 	require.NoError(t, err)
@@ -411,14 +409,15 @@
 	require.Len(t, rfh.ReadyForHandshake, 1)
 	require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
 
-	require.NoError(t, uut.Close())
+	go uut.Close(ctx)
+	dis := testutil.RequireRecvCtx(ctx, t, reqs)
+	require.NotNil(t, dis)
+	require.NotNil(t, dis.Disconnect)
+	close(resps)
 
-	select {
-	case err := <-uut.Error():
-		require.ErrorContains(t, err, "stream terminated by sending close")
-	default:
-		// OK!
-	}
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+	require.ErrorIs(t, err, io.EOF)
 }
 
 // coordinationTest tests that a coordination behaves correctly
@@ -464,13 +463,18 @@ func coordinationTest(
 	require.Len(t, fConn.updates[0], 1)
 	require.Equal(t, agentID[:], fConn.updates[0][0].Id)
 
-	err = uut.Close()
-	require.NoError(t, err)
-	uut.Error()
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- uut.Close(ctx)
+	}()
 
 	// When we close, it should gracefully disconnect
 	req = testutil.RequireRecvCtx(ctx, t, reqs)
 	require.NotNil(t, req.Disconnect)
+	close(resps)
+
+	err = testutil.RequireRecvCtx(ctx, t, errCh)
+	require.NoError(t, err)
 
 	// It should set all peers lost on the coordinatee
 	require.Equal(t, 1, fConn.setAllPeersLostCalls)
diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go
index 41326caaa7..0d3956cf44 100644
--- a/tailnet/test/integration/integration.go
+++ b/tailnet/test/integration/integration.go
@@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
 	coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
 	t.Cleanup(func() {
-		_ = coordination.Close()
+		cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+		defer cancel()
+		_ = coordination.Close(cctx)
 	})
 
 	return conn
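
For reviewers: the disconnect handshake this patch standardizes on, reduced to
a minimal standalone sketch. The stream and coordination types below are
hypothetical stand-ins rather than the actual coder/coder API; the real logic
is remoteCoordination.Close in tailnet/coordinator.go above.

	package sketch

	import (
		"context"
		"errors"
		"fmt"
	)

	// stream stands in for the dRPC coordination stream.
	type stream interface {
		SendDisconnect() error
		Close() error
	}

	type coordination struct {
		proto        stream
		respLoopDone chan struct{} // closed when the server hangs up and the recv loop exits
	}

	// closeGracefully sends Disconnect, then waits for the server to hang up,
	// bounded by ctx. Hard-closing the stream is the fallback, not the first
	// step, because closing can take effect before the server has processed
	// the Disconnect message.
	func (c *coordination) closeGracefully(ctx context.Context) error {
		if err := c.proto.SendDisconnect(); err != nil {
			_ = c.proto.Close()
			return fmt.Errorf("send disconnect: %w", err)
		}
		select {
		case <-c.respLoopDone:
			// The server processed the Disconnect and hung up; nothing to tear down.
			return nil
		case <-ctx.Done():
			// The graceful-disconnect timer fired: force-close from our end.
			err := c.proto.Close()
			<-c.respLoopDone
			return errors.Join(ctx.Err(), err)
		}
	}

Callers bound the wait with a timeout context, mirroring the
context.WithTimeout(..., testutil.WaitShort) calls in the tests above.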