Mirror of https://github.com/coder/coder.git (synced 2025-07-12 00:14:10 +00:00)
fix: fix tailnet remoteCoordination to wait for server (#14666)
Fixes #12560

When gracefully disconnecting from the coordinator, we would send the Disconnect message and then close the dRPC stream. However, closing the dRPC stream can cause the server not to process the Disconnect message, since we use the stream context in a `select` while sending it to the coordinator. This is a product bug uncovered by the flaky test, and it probably causes graceful disconnect to fail some minority of the time.

Instead, `remoteCoordination` (and `inMemoryCoordination`, for consistency) now sends the Disconnect message and then waits for the coordinator to hang up (bounded by a graceful-disconnect timer, in the form of a context).
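To make the race concrete, here is a minimal, hypothetical sketch (a toy channel and context standing in for the dRPC stream; none of this is the actual coder code): a send guarded by the stream's context can be skipped entirely if the stream is closed first, which is how the Disconnect message could get lost.

package main

import (
	"context"
	"fmt"
)

// sendDisconnect mimics a send guarded by the stream context: if the stream
// is closed (its context canceled) before the send is chosen by the select,
// the Disconnect message is silently dropped.
func sendDisconnect(streamCtx context.Context, reqs chan<- string) error {
	select {
	case <-streamCtx.Done():
		return streamCtx.Err()
	case reqs <- "disconnect":
		return nil
	}
}

func main() {
	reqs := make(chan string) // unbuffered: the coordinator hasn't received yet
	streamCtx, closeStream := context.WithCancel(context.Background())
	closeStream() // close the stream right after queueing the send, as the old code could
	// The select takes the Done branch: the coordinator never sees Disconnect.
	fmt.Println(sendDisconnect(streamCtx, reqs))
}

The fix below inverts the shutdown order: send Disconnect first, then keep the stream open until the coordinator hangs up or a graceful-disconnect context expires.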
@@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 	defer close(errCh)
 	select {
 	case <-ctx.Done():
-		err := coordination.Close()
+		err := coordination.Close(a.hardCtx)
 		if err != nil {
 			a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
 		}
@@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
 		coordinator, conn)
 	t.Cleanup(func() {
 		t.Logf("closing coordination %s", name)
-		err := coordination.Close()
+		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+		defer ccancel()
+		err := coordination.Close(cctx)
 		if err != nil {
 			t.Logf("error closing in-memory coordination: %s", err.Error())
 		}
@@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 		clientID, metadata.AgentID,
 		coordinator, conn)
 	t.Cleanup(func() {
-		err := coordination.Close()
+		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
+		defer ccancel()
+		err := coordination.Close(cctx)
 		if err != nil {
 			t.Logf("error closing in-mem coordination: %s", err.Error())
 		}
@@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
 	select {
 	case <-tac.ctx.Done():
 		tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
-		crdErr := coordination.Close()
+		crdErr := coordination.Close(tac.gracefulCtx)
 		if crdErr != nil {
 			tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
 		}
@@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 	derpMapCh := make(chan *tailcfg.DERPMap)
 	defer close(derpMapCh)
 	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
-		Logger:                  logger,
+		Logger:                  logger.Named("svc"),
 		CoordPtr:                &coordPtr,
 		DERPMapUpdateFrequency:  time.Millisecond,
 		DERPMapFn:               func() *tailcfg.DERPMap { return <-derpMapCh },
@@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 
 	fConn := newFakeTailnetConn()
 
-	uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
+	uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
+		quartz.NewReal(), &websocket.DialOptions{})
 	uut.runConnector(fConn)
 
 	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
@@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
 	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
 	require.NotNil(t, reqDisc)
 	require.NotNil(t, reqDisc.Disconnect)
+	close(call.Resps)
 }
 
 func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
@@ -91,7 +91,7 @@ type Coordinatee interface {
 }
 
 type Coordination interface {
-	io.Closer
+	Close(context.Context) error
 	Error() <-chan error
 }
 
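Dropping `io.Closer` means every call site must now bound the graceful wait itself. A minimal sketch of the migrated calling convention (the no-op implementation and the 10-second timeout are illustrative assumptions, not the real tailnet types):

package main

import (
	"context"
	"log"
	"time"
)

// Coordination mirrors the interface after this change: Close takes a
// context that bounds how long we wait for the server to hang up.
type Coordination interface {
	Close(context.Context) error
	Error() <-chan error
}

// noopCoordination is a stand-in implementation for illustration only.
type noopCoordination struct{ errCh chan error }

func (noopCoordination) Close(context.Context) error { return nil }
func (c noopCoordination) Error() <-chan error       { return c.errCh }

func main() {
	var c Coordination = noopCoordination{errCh: make(chan error, 1)}
	// Callers bound the graceful disconnect with a timeout, as the agent
	// does with a.hardCtx and the tests do with testutil.WaitShort.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := c.Close(ctx); err != nil {
		log.Printf("failed to close coordination: %s", err)
	}
}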
@@ -106,7 +106,10 @@ type remoteCoordination struct {
 	respLoopDone chan struct{}
 }
 
-func (c *remoteCoordination) Close() (retErr error) {
+// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
+// waiting for the server to hang up the coordination. If the provided context expires, we stop
+// waiting for the server and close the coordination stream from our end.
+func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
 	c.Lock()
 	defer c.Unlock()
 	if c.closed {
@@ -114,6 +117,18 @@ func (c *remoteCoordination) Close() (retErr error) {
 	}
 	c.closed = true
 	defer func() {
+		// We shouldn't just close the protocol right away, because the way dRPC streams work is
+		// that if you close them, that could take effect immediately, even before the Disconnect
+		// message is processed. Coordinators are supposed to hang up on us once they get a
+		// Disconnect message, so we should wait around for that until the context expires.
+		select {
+		case <-c.respLoopDone:
+			c.logger.Debug(ctx, "responses closed after disconnect")
+			return
+		case <-ctx.Done():
+			c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
+		}
+		// forcefully close the stream
 		protoErr := c.protocol.Close()
 		<-c.respLoopDone
 		if retErr == nil {
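The deferred block above is the core of the fix: wait for the response loop to end (which happens when the coordinator hangs up) and only force-close the stream once the context expires. A standalone sketch of that shutdown ladder, with assumed names (respLoopDone, closeStream) in place of the real tailnet internals:

package main

import (
	"context"
	"fmt"
	"time"
)

// closeAfterDisconnect waits for respLoopDone (closed when the server hangs
// up) and falls back to force-closing the stream when ctx expires.
func closeAfterDisconnect(ctx context.Context, respLoopDone <-chan struct{}, closeStream func() error) error {
	select {
	case <-respLoopDone:
		return nil // server hung up gracefully; nothing to force
	case <-ctx.Done():
	}
	// Deadline hit: force-close, then wait for the response loop to unwind.
	err := closeStream()
	<-respLoopDone
	return err
}

func main() {
	respLoopDone := make(chan struct{})
	closeStream := func() error {
		close(respLoopDone) // closing the stream ends the response loop
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	// No graceful hang-up ever arrives here, so we time out and force-close.
	fmt.Println(closeAfterDisconnect(ctx, respLoopDone, closeStream))
}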
@@ -240,7 +255,6 @@ type inMemoryCoordination struct {
 	ctx          context.Context
 	errChan      chan error
 	closed       bool
-	closedCh     chan struct{}
 	respLoopDone chan struct{}
 	coordinatee  Coordinatee
 	logger       slog.Logger
@@ -280,7 +294,6 @@ func NewInMemoryCoordination(
 		errChan:      make(chan error, 1),
 		coordinatee:  coordinatee,
 		logger:       logger,
-		closedCh:     make(chan struct{}),
 		respLoopDone: make(chan struct{}),
 	}
 
@@ -328,16 +341,7 @@ func (c *inMemoryCoordination) respLoop() {
 		c.coordinatee.SetAllPeersLost()
 		close(c.respLoopDone)
 	}()
-	for {
-		select {
-		case <-c.closedCh:
-			c.logger.Debug(context.Background(), "in-memory coordination closed")
-			return
-		case resp, ok := <-c.resps:
-			if !ok {
-				c.logger.Debug(context.Background(), "in-memory response channel closed")
-				return
-			}
+	for resp := range c.resps {
 		c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
 		err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
 		if err != nil {
@@ -345,7 +349,7 @@ func (c *inMemoryCoordination) respLoop() {
 			return
 		}
 	}
-	}
+	c.logger.Debug(context.Background(), "in-memory response channel closed")
 }
 
 func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
@@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
 	return ch
 }
 
-func (c *inMemoryCoordination) Close() error {
+// Close attempts to gracefully close the inMemoryCoordination by sending a Disconnect message and
+// waiting for the server to hang up the coordination. If the provided context expires, we stop
+// waiting for the server and close the coordination stream from our end.
+func (c *inMemoryCoordination) Close(ctx context.Context) error {
 	c.Lock()
 	defer c.Unlock()
 	c.logger.Debug(context.Background(), "closing in-memory coordination")
@@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
 	}
 	defer close(c.reqs)
 	c.closed = true
-	close(c.closedCh)
-	<-c.respLoopDone
 	select {
-	case <-c.ctx.Done():
+	case <-ctx.Done():
 		return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
 	case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
 		c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
-		return nil
 	}
+
+	select {
+	case <-ctx.Done():
+		return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
+	case <-c.respLoopDone:
+		return nil
+	}
 }
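The in-memory variant has the same two-phase shape: a bounded send of the Disconnect request, then a bounded wait for the response loop to finish. A self-contained sketch of that pattern with toy channels (assumed names, not the real proto types):

package main

import (
	"context"
	"fmt"
	"time"
)

// disconnect sends a disconnect request under a deadline, then waits (under
// the same deadline) for the response loop to finish.
func disconnect(ctx context.Context, reqs chan<- string, respLoopDone <-chan struct{}) error {
	select {
	case <-ctx.Done():
		return fmt.Errorf("failed to gracefully disconnect: %w", ctx.Err())
	case reqs <- "disconnect":
	}
	select {
	case <-ctx.Done():
		return fmt.Errorf("context expired waiting for responses to close: %w", ctx.Err())
	case <-respLoopDone:
		return nil
	}
}

func main() {
	reqs := make(chan string, 1)
	respLoopDone := make(chan struct{})
	go func() {
		<-reqs              // the coordinator receives the disconnect...
		close(respLoopDone) // ...and hangs up, which ends the response loop
	}()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println(disconnect(ctx, reqs, respLoopDone))
}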
@@ -2,6 +2,7 @@ package tailnet_test
 
 import (
 	"context"
+	"io"
 	"net"
 	"net/netip"
 	"sync"
@@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
 		Times(1).Return(reqs, resps)
 
 	uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
 
@@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
 	require.NoError(t, err)
 
 	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
 
-	select {
-	case err := <-uut.Error():
-		require.ErrorContains(t, err, "stream terminated by sending close")
-	default:
-		// OK!
-	}
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+	require.ErrorIs(t, err, io.EOF)
 }
 
 func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
@@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
 	require.NoError(t, err)
 
 	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
-	defer uut.Close()
+	defer uut.Close(ctx)
 
 	nk, err := key.NewNode().Public().MarshalBinary()
 	require.NoError(t, err)
@@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
 	require.Len(t, rfh.ReadyForHandshake, 1)
 	require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
 
-	require.NoError(t, uut.Close())
+	go uut.Close(ctx)
+	dis := testutil.RequireRecvCtx(ctx, t, reqs)
+	require.NotNil(t, dis)
+	require.NotNil(t, dis.Disconnect)
+	close(resps)
 
-	select {
-	case err := <-uut.Error():
-		require.ErrorContains(t, err, "stream terminated by sending close")
-	default:
-		// OK!
-	}
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
+	require.ErrorIs(t, err, io.EOF)
 }
 
 // coordinationTest tests that a coordination behaves correctly
@@ -464,13 +463,18 @@ func coordinationTest(
 	require.Len(t, fConn.updates[0], 1)
 	require.Equal(t, agentID[:], fConn.updates[0][0].Id)
 
-	err = uut.Close()
-	require.NoError(t, err)
-	uut.Error()
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- uut.Close(ctx)
+	}()
 
 	// When we close, it should gracefully disconnect
 	req = testutil.RequireRecvCtx(ctx, t, reqs)
 	require.NotNil(t, req.Disconnect)
+	close(resps)
+
+	err = testutil.RequireRecvCtx(ctx, t, errCh)
+	require.NoError(t, err)
 
 	// It should set all peers lost on the coordinatee
 	require.Equal(t, 1, fConn.setAllPeersLostCalls)
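Because Close now blocks until the server hangs up, coordinationTest runs it on a goroutine and collects the result through an error channel. A self-contained sketch of that test shape using only the standard library (closeFn is a hypothetical stand-in for Coordination.Close):

package example

import (
	"context"
	"testing"
	"time"
)

// TestBlockingClose starts a blocking close concurrently, plays the server's
// part (observe the disconnect, then hang up), and asserts on the close error.
func TestBlockingClose(t *testing.T) {
	reqs := make(chan string, 1)
	done := make(chan struct{})
	closeFn := func(ctx context.Context) error {
		reqs <- "disconnect"
		select {
		case <-done:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	errCh := make(chan error, 1)
	go func() { errCh <- closeFn(ctx) }()

	if req := <-reqs; req != "disconnect" {
		t.Fatalf("unexpected request %q", req)
	}
	close(done) // the server hangs up; Close should return nil
	if err := <-errCh; err != nil {
		t.Fatalf("close failed: %s", err)
	}
}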
@@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
 
 	coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
 	t.Cleanup(func() {
-		_ = coordination.Close()
+		cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
+		defer cancel()
+		_ = coordination.Close(cctx)
 	})
 
 	return conn