mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
fix: fix tailnet remoteCoordination to wait for server (#14666)
Fixes #12560. When gracefully disconnecting from the coordinator, we would send the Disconnect message and then immediately close the dRPC stream. However, closing the dRPC stream can cause the server not to process the Disconnect message, since we use the stream context in a `select` while sending it to the coordinator. This is a product bug uncovered by the flaky test, and it probably causes graceful disconnect to fail some minority of the time. Instead, the `remoteCoordination` (and `inMemoryCoordination`, for consistency) should send the Disconnect message and then wait for the coordinator to hang up, with the wait bounded by a graceful-disconnect timeout passed in as a context.
This commit is contained in:
@ -1357,7 +1357,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
|
|||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
err := coordination.Close()
|
err := coordination.Close(a.hardCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
|
a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
|
||||||
}
|
}
|
||||||
|
@ -1896,7 +1896,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
|
|||||||
coordinator, conn)
|
coordinator, conn)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
t.Logf("closing coordination %s", name)
|
t.Logf("closing coordination %s", name)
|
||||||
err := coordination.Close()
|
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||||||
|
defer ccancel()
|
||||||
|
err := coordination.Close(cctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("error closing in-memory coordination: %s", err.Error())
|
t.Logf("error closing in-memory coordination: %s", err.Error())
|
||||||
}
|
}
|
||||||
@ -2384,7 +2386,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
|
|||||||
clientID, metadata.AgentID,
|
clientID, metadata.AgentID,
|
||||||
coordinator, conn)
|
coordinator, conn)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := coordination.Close()
|
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
|
||||||
|
defer ccancel()
|
||||||
|
err := coordination.Close(cctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("error closing in-mem coordination: %s", err.Error())
|
t.Logf("error closing in-mem coordination: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
@ -277,7 +277,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
|
|||||||
select {
|
select {
|
||||||
case <-tac.ctx.Done():
|
case <-tac.ctx.Done():
|
||||||
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
||||||
crdErr := coordination.Close()
|
crdErr := coordination.Close(tac.gracefulCtx)
|
||||||
if crdErr != nil {
|
if crdErr != nil {
|
||||||
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
|
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
|
||||||
}
|
}
|
||||||
|
@ -57,7 +57,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||||
defer close(derpMapCh)
|
defer close(derpMapCh)
|
||||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||||
Logger: logger,
|
Logger: logger.Named("svc"),
|
||||||
CoordPtr: &coordPtr,
|
CoordPtr: &coordPtr,
|
||||||
DERPMapUpdateFrequency: time.Millisecond,
|
DERPMapUpdateFrequency: time.Millisecond,
|
||||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||||
@ -82,7 +82,8 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||||||
|
|
||||||
fConn := newFakeTailnetConn()
|
fConn := newFakeTailnetConn()
|
||||||
|
|
||||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
|
||||||
|
quartz.NewReal(), &websocket.DialOptions{})
|
||||||
uut.runConnector(fConn)
|
uut.runConnector(fConn)
|
||||||
|
|
||||||
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||||
@ -108,6 +109,7 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||||||
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
||||||
require.NotNil(t, reqDisc)
|
require.NotNil(t, reqDisc)
|
||||||
require.NotNil(t, reqDisc.Disconnect)
|
require.NotNil(t, reqDisc.Disconnect)
|
||||||
|
close(call.Resps)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
||||||
|
@ -91,7 +91,7 @@ type Coordinatee interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Coordination interface {
|
type Coordination interface {
|
||||||
io.Closer
|
Close(context.Context) error
|
||||||
Error() <-chan error
|
Error() <-chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,7 +106,10 @@ type remoteCoordination struct {
|
|||||||
respLoopDone chan struct{}
|
respLoopDone chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *remoteCoordination) Close() (retErr error) {
|
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
|
||||||
|
// waiting for the server to hang up the coordination. If the provided context expires, we stop
|
||||||
|
// waiting for the server and close the coordination stream from our end.
|
||||||
|
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
@ -114,6 +117,18 @@ func (c *remoteCoordination) Close() (retErr error) {
|
|||||||
}
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// We shouldn't just close the protocol right away, because the way dRPC streams work is
|
||||||
|
// that if you close them, that could take effect immediately, even before the Disconnect
|
||||||
|
// message is processed. Coordinators are supposed to hang up on us once they get a
|
||||||
|
// Disconnect message, so we should wait around for that until the context expires.
|
||||||
|
select {
|
||||||
|
case <-c.respLoopDone:
|
||||||
|
c.logger.Debug(ctx, "responses closed after disconnect")
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
|
||||||
|
}
|
||||||
|
// forcefully close the stream
|
||||||
protoErr := c.protocol.Close()
|
protoErr := c.protocol.Close()
|
||||||
<-c.respLoopDone
|
<-c.respLoopDone
|
||||||
if retErr == nil {
|
if retErr == nil {
|
||||||
@ -240,7 +255,6 @@ type inMemoryCoordination struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
errChan chan error
|
errChan chan error
|
||||||
closed bool
|
closed bool
|
||||||
closedCh chan struct{}
|
|
||||||
respLoopDone chan struct{}
|
respLoopDone chan struct{}
|
||||||
coordinatee Coordinatee
|
coordinatee Coordinatee
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
@ -280,7 +294,6 @@ func NewInMemoryCoordination(
|
|||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
coordinatee: coordinatee,
|
coordinatee: coordinatee,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
closedCh: make(chan struct{}),
|
|
||||||
respLoopDone: make(chan struct{}),
|
respLoopDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -328,24 +341,15 @@ func (c *inMemoryCoordination) respLoop() {
|
|||||||
c.coordinatee.SetAllPeersLost()
|
c.coordinatee.SetAllPeersLost()
|
||||||
close(c.respLoopDone)
|
close(c.respLoopDone)
|
||||||
}()
|
}()
|
||||||
for {
|
for resp := range c.resps {
|
||||||
select {
|
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
|
||||||
case <-c.closedCh:
|
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
||||||
c.logger.Debug(context.Background(), "in-memory coordination closed")
|
if err != nil {
|
||||||
|
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
|
||||||
return
|
return
|
||||||
case resp, ok := <-c.resps:
|
|
||||||
if !ok {
|
|
||||||
c.logger.Debug(context.Background(), "in-memory response channel closed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
|
|
||||||
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
|
|
||||||
if err != nil {
|
|
||||||
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
c.logger.Debug(context.Background(), "in-memory response channel closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
|
func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
|
||||||
@ -355,7 +359,10 @@ func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *inMemoryCoordination) Close() error {
|
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
|
||||||
|
// waiting for the server to hang up the coordination. If the provided context expires, we stop
|
||||||
|
// waiting for the server and close the coordination stream from our end.
|
||||||
|
func (c *inMemoryCoordination) Close(ctx context.Context) error {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
c.logger.Debug(context.Background(), "closing in-memory coordination")
|
c.logger.Debug(context.Background(), "closing in-memory coordination")
|
||||||
@ -364,13 +371,17 @@ func (c *inMemoryCoordination) Close() error {
|
|||||||
}
|
}
|
||||||
defer close(c.reqs)
|
defer close(c.reqs)
|
||||||
c.closed = true
|
c.closed = true
|
||||||
close(c.closedCh)
|
|
||||||
<-c.respLoopDone
|
|
||||||
select {
|
select {
|
||||||
case <-c.ctx.Done():
|
case <-ctx.Done():
|
||||||
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
|
||||||
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
|
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
|
||||||
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
|
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
|
||||||
|
case <-c.respLoopDone:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package tailnet_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@ -284,7 +285,7 @@ func TestInMemoryCoordination(t *testing.T) {
|
|||||||
Times(1).Return(reqs, resps)
|
Times(1).Return(reqs, resps)
|
||||||
|
|
||||||
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
|
||||||
defer uut.Close()
|
defer uut.Close(ctx)
|
||||||
|
|
||||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||||
|
|
||||||
@ -336,16 +337,13 @@ func TestRemoteCoordination(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
|
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
|
||||||
defer uut.Close()
|
defer uut.Close(ctx)
|
||||||
|
|
||||||
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
||||||
|
|
||||||
select {
|
// Recv loop should be terminated by the server hanging up after Disconnect
|
||||||
case err := <-uut.Error():
|
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
|
||||||
require.ErrorContains(t, err, "stream terminated by sending close")
|
require.ErrorIs(t, err, io.EOF)
|
||||||
default:
|
|
||||||
// OK!
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
|
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
|
||||||
@ -388,7 +386,7 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
|
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
|
||||||
defer uut.Close()
|
defer uut.Close(ctx)
|
||||||
|
|
||||||
nk, err := key.NewNode().Public().MarshalBinary()
|
nk, err := key.NewNode().Public().MarshalBinary()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -411,14 +409,15 @@ func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
|
|||||||
require.Len(t, rfh.ReadyForHandshake, 1)
|
require.Len(t, rfh.ReadyForHandshake, 1)
|
||||||
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
|
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
|
||||||
|
|
||||||
require.NoError(t, uut.Close())
|
go uut.Close(ctx)
|
||||||
|
dis := testutil.RequireRecvCtx(ctx, t, reqs)
|
||||||
|
require.NotNil(t, dis)
|
||||||
|
require.NotNil(t, dis.Disconnect)
|
||||||
|
close(resps)
|
||||||
|
|
||||||
select {
|
// Recv loop should be terminated by the server hanging up after Disconnect
|
||||||
case err := <-uut.Error():
|
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
|
||||||
require.ErrorContains(t, err, "stream terminated by sending close")
|
require.ErrorIs(t, err, io.EOF)
|
||||||
default:
|
|
||||||
// OK!
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// coordinationTest tests that a coordination behaves correctly
|
// coordinationTest tests that a coordination behaves correctly
|
||||||
@ -464,13 +463,18 @@ func coordinationTest(
|
|||||||
require.Len(t, fConn.updates[0], 1)
|
require.Len(t, fConn.updates[0], 1)
|
||||||
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
||||||
|
|
||||||
err = uut.Close()
|
errCh := make(chan error, 1)
|
||||||
require.NoError(t, err)
|
go func() {
|
||||||
uut.Error()
|
errCh <- uut.Close(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
// When we close, it should gracefully disconnect
|
// When we close, it should gracefully disconnect
|
||||||
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
||||||
require.NotNil(t, req.Disconnect)
|
require.NotNil(t, req.Disconnect)
|
||||||
|
close(resps)
|
||||||
|
|
||||||
|
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// It should set all peers lost on the coordinatee
|
// It should set all peers lost on the coordinatee
|
||||||
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
||||||
|
@ -469,7 +469,9 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
|
|||||||
|
|
||||||
coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
|
coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_ = coordination.Close()
|
cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||||
|
defer cancel()
|
||||||
|
_ = coordination.Close(cctx)
|
||||||
})
|
})
|
||||||
|
|
||||||
return conn
|
return conn
|
||||||
|
Reference in New Issue
Block a user