fix: fix tailnet remoteCoordination to wait for server (#14666)

Fixes #12560

When gracefully disconnecting from the coordinator, we would send the Disconnect message and then close the dRPC stream.  However, closing the dRPC stream can cause the server not to process the Disconnect message, since we use the stream context in a `select` while sending it to the coordinator.

This is a product bug uncovered by the flake, and probably results in us failing graceful disconnect some minority of the time.

Instead, the `remoteCoordination` (and `inMemoryCoordination` for consistency) should send the Disconnect message and then wait for the coordinator to hang up (on some graceful disconnect timer, in the form of a context).
This commit is contained in:
Spike Curtis
2024-09-16 09:24:30 +04:00
committed by GitHub
parent 7ea8a2253e
commit 2df9a3e554
7 changed files with 72 additions and 49 deletions

View File

@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
defer close(errCh) defer close(errCh)
select { select {
case <-ctx.Done(): case <-ctx.Done():
err := coordination.Close() err := coordination.Close(a.hardCtx)
if err != nil { if err != nil {
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err)) a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
} }

View File

@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
coordinator, conn) coordinator, conn)
t.Cleanup(func() { t.Cleanup(func() {
t.Logf("closing coordination %s", name) t.Logf("closing coordination %s", name)
err := coordination.Close() cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
if err != nil { if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error()) t.Logf("error closing in-memory coordination: %s", err.Error())
} }
@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
clientID, metadata.AgentID, clientID, metadata.AgentID,
coordinator, conn) coordinator, conn)
t.Cleanup(func() { t.Cleanup(func() {
err := coordination.Close() cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
if err != nil { if err != nil {
t.Logf("error closing in-mem coordination: %s", err.Error()) t.Logf("error closing in-mem coordination: %s", err.Error())
} }

View File

@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
select { select {
case <-tac.ctx.Done(): case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect") tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close() crdErr := coordination.Close(tac.gracefulCtx)
if crdErr != nil { if crdErr != nil {
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
} }

View File

@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
derpMapCh := make(chan *tailcfg.DERPMap) derpMapCh := make(chan *tailcfg.DERPMap)
defer close(derpMapCh) defer close(derpMapCh)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger, Logger: logger.Named("svc"),
CoordPtr: &coordPtr, CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Millisecond, DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
fConn := newFakeTailnetConn() fConn := newFakeTailnetConn()
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn) uut.runConnector(fConn)
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs) reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
require.NotNil(t, reqDisc) require.NotNil(t, reqDisc)
require.NotNil(t, reqDisc.Disconnect) require.NotNil(t, reqDisc.Disconnect)
close(call.Resps)
} }
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) { func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {

View File

@ -91,7 +91,7 @@ type Coordinatee interface {
} }
type Coordination interface { type Coordination interface {
io.Closer Close(context.Context) error
Error() <-chan error Error() <-chan error
} }
@ -106,7 +106,10 @@ type remoteCoordination struct {
respLoopDone chan struct{} respLoopDone chan struct{}
} }
func (c *remoteCoordination) Close() (retErr error) { // Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if c.closed { if c.closed {
@ -114,6 +117,18 @@ func (c *remoteCoordination) Close() (retErr error) {
} }
c.closed = true c.closed = true
defer func() { defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.protocol.Close() protoErr := c.protocol.Close()
<-c.respLoopDone <-c.respLoopDone
if retErr == nil { if retErr == nil {
@ -240,7 +255,6 @@ type inMemoryCoordination struct {
ctx context.Context ctx context.Context
errChan chan error errChan chan error
closed bool closed bool
closedCh chan struct{}
respLoopDone chan struct{} respLoopDone chan struct{}
coordinatee Coordinatee coordinatee Coordinatee
logger slog.Logger logger slog.Logger
@ -280,7 +294,6 @@ func NewInMemoryCoordination(
errChan: make(chan error, 1), errChan: make(chan error, 1),
coordinatee: coordinatee, coordinatee: coordinatee,
logger: logger, logger: logger,
closedCh: make(chan struct{}),
respLoopDone: make(chan struct{}), respLoopDone: make(chan struct{}),
} }
@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
c.coordinatee.SetAllPeersLost() c.coordinatee.SetAllPeersLost()
close(c.respLoopDone) close(c.respLoopDone)
}() }()
for { for resp := range c.resps {
select { c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
case <-c.closedCh: err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
c.logger.Debug(context.Background(), "in-memory coordination closed") if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return return
case resp, ok := <-c.resps:
if !ok {
c.logger.Debug(context.Background(), "in-memory response channel closed")
return
}
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
}
} }
} }
c.logger.Debug(context.Background(), "in-memory response channel closed")
} }
func (*inMemoryCoordination) AwaitAck() <-chan struct{} { func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
return ch return ch
} }
func (c *inMemoryCoordination) Close() error { // Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *inMemoryCoordination) Close(ctx context.Context) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.logger.Debug(context.Background(), "closing in-memory coordination") c.logger.Debug(context.Background(), "closing in-memory coordination")
@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
} }
defer close(c.reqs) defer close(c.reqs)
c.closed = true c.closed = true
close(c.closedCh)
<-c.respLoopDone
select { select {
case <-c.ctx.Done(): case <-ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}: case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory") c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
}
select {
case <-ctx.Done():
return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
case <-c.respLoopDone:
return nil return nil
} }
} }

View File

@ -2,6 +2,7 @@ package tailnet_test
import ( import (
"context" "context"
"io"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
Times(1).Return(reqs, resps) Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close() defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID) uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close() defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
select { // Recv loop should be terminated by the server hanging up after Disconnect
case err := <-uut.Error(): err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorContains(t, err, "stream terminated by sending close") require.ErrorIs(t, err, io.EOF)
default:
// OK!
}
} }
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) { func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{}) uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
defer uut.Close() defer uut.Close(ctx)
nk, err := key.NewNode().Public().MarshalBinary() nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err) require.NoError(t, err)
@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
require.Len(t, rfh.ReadyForHandshake, 1) require.Len(t, rfh.ReadyForHandshake, 1)
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id) require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
require.NoError(t, uut.Close()) go uut.Close(ctx)
dis := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, dis)
require.NotNil(t, dis.Disconnect)
close(resps)
select { // Recv loop should be terminated by the server hanging up after Disconnect
case err := <-uut.Error(): err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorContains(t, err, "stream terminated by sending close") require.ErrorIs(t, err, io.EOF)
default:
// OK!
}
} }
// coordinationTest tests that a coordination behaves correctly // coordinationTest tests that a coordination behaves correctly
@ -464,13 +463,18 @@ func coordinationTest(
require.Len(t, fConn.updates[0], 1) require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id) require.Equal(t, agentID[:], fConn.updates[0][0].Id)
err = uut.Close() errCh := make(chan error, 1)
require.NoError(t, err) go func() {
uut.Error() errCh <- uut.Close(ctx)
}()
// When we close, it should gracefully disconnect // When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs) req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect) require.NotNil(t, req.Disconnect)
close(resps)
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)
// It should set all peers lost on the coordinatee // It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls) require.Equal(t, 1, fConn.setAllPeersLostCalls)

View File

@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID) coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
t.Cleanup(func() { t.Cleanup(func() {
_ = coordination.Close() cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_ = coordination.Close(cctx)
}) })
return conn return conn