From 886dcbec843766da9e459557335ccdf5dbea7ac6 Mon Sep 17 00:00:00 2001
From: Spike Curtis
Date: Tue, 5 Nov 2024 13:50:10 +0400
Subject: [PATCH] chore: refactor coordination (#15343)

Refactors the way clients of the Tailnet API (which include both
workspace "agents" and "clients") interact with the API.

Introduces the idea of abstract "controllers" for each of the RPCs in
the API, and implements a Coordination controller by refactoring from
`workspacesdk`.

chore re: #14729
---
 agent/agent.go                          |   5 +-
 agent/agent_test.go                     |  13 +-
 agent/agenttest/client.go               |   2 -
 codersdk/workspacesdk/connector.go      |   6 +-
 tailnet/controllers.go                  | 361 ++++++++++++++++++++++++
 tailnet/controllers_test.go             | 283 +++++++++++++++++++
 tailnet/coordinator.go                  | 296 -------------------
 tailnet/coordinator_test.go             | 267 ------------------
 tailnet/test/integration/integration.go |   3 +-
 9 files changed, 658 insertions(+), 578 deletions(-)
 create mode 100644 tailnet/controllers.go
 create mode 100644 tailnet/controllers_test.go

diff --git a/agent/agent.go b/agent/agent.go
index cb0037dd0e..4c8497d105 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -1352,7 +1352,8 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 	defer close(disconnected)
 	a.closeMutex.Unlock()
 
-	coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
+	ctrl := tailnet.NewAgentCoordinationController(a.logger, network)
+	coordination := ctrl.New(coordinate)
 
 	errCh := make(chan error, 1)
 	go func() {
@@ -1364,7 +1365,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 				a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
 			}
 			return
-		case err := <-coordination.Error():
+		case err := <-coordination.Wait():
 			errCh <- err
 		}
 	}()
diff --git a/agent/agent_test.go b/agent/agent_test.go
index addae8c3d8..e7fd753b8d 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -1918,10 +1918,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
 		testCtx, testCtxCancel := context.WithCancel(context.Background())
 		t.Cleanup(testCtxCancel)
 		clientID := uuid.New()
-		coordination := tailnet.NewInMemoryCoordination(
-			testCtx, logger,
-			clientID, agentID,
-			coordinator, conn)
+		ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
+		coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, coordinator))
 		t.Cleanup(func() {
 			t.Logf("closing coordination %s", name)
 			cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
@@ -2409,10 +2407,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	testCtx, testCtxCancel := context.WithCancel(context.Background())
 	t.Cleanup(testCtxCancel)
 	clientID := uuid.New()
-	coordination := tailnet.NewInMemoryCoordination(
-		testCtx, logger,
-		clientID, metadata.AgentID,
-		coordinator, conn)
+	ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID)
+	coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
+		logger, clientID, metadata.AgentID, coordinator))
 	t.Cleanup(func() {
 		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
 		defer ccancel()
diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go
index a17f9200a9..8817b311fc 100644
--- a/agent/agenttest/client.go
+++ b/agent/agenttest/client.go
@@ -71,7 +71,6 @@ func NewClient(t testing.TB,
 		t:              t,
 		logger:         logger.Named("client"),
 		agentID:        agentID,
-		coordinator:    coordinator,
 		server:         server,
 		fakeAgentAPI:   fakeAAPI,
 		derpMapUpdates: derpMapUpdates,
@@ 
-82,7 +81,6 @@ type Client struct { t testing.TB logger slog.Logger agentID uuid.UUID - coordinator tailnet.Coordinator server *drpcserver.Server fakeAgentAPI *FakeAgentAPI LastWorkspaceAgent func() diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go index 780478e91a..eb14b34519 100644 --- a/codersdk/workspacesdk/connector.go +++ b/codersdk/workspacesdk/connector.go @@ -66,6 +66,7 @@ type tailnetAPIConnector struct { clock quartz.Clock dialOptions *websocket.DialOptions conn tailnetConn + coordCtrl tailnet.CoordinationController customDialFn func() (proto.DRPCTailnetClient, error) clientMu sync.RWMutex @@ -112,6 +113,7 @@ func (tac *tailnetAPIConnector) manageGracefulTimeout() { // Runs a tailnetAPIConnector using the provided connection func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) { tac.conn = conn + tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID) tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background()) go tac.manageGracefulTimeout() go func() { @@ -272,7 +274,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr)) } }() - coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID) + coordination := tac.coordCtrl.New(coord) tac.logger.Debug(tac.ctx, "serving coordinator") select { case <-tac.ctx.Done(): @@ -281,7 +283,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) { if crdErr != nil { tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err)) } - case err = <-coordination.Error(): + case err = <-coordination.Wait(): if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) && diff --git a/tailnet/controllers.go b/tailnet/controllers.go new file mode 100644 index 0000000000..84a9bd4d79 --- /dev/null +++ b/tailnet/controllers.go @@ -0,0 +1,361 @@ +package tailnet + +import ( + "context" + "fmt" + "io" + "sync" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "storj.io/drpc" + "tailscale.com/tailcfg" + + "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet/proto" +) + +// A Controller connects to the tailnet control plane, and then uses the control protocols to +// program a tailnet.Conn in production (in test it could be an interface simulating the Conn). It +// delegates this task to sub-controllers responsible for the main areas of the tailnet control +// protocol: coordination, DERP map updates, resume tokens, and telemetry. +type Controller struct { + Dialer ControlProtocolDialer + CoordCtrl CoordinationController + DERPCtrl DERPController + ResumeTokenCtrl ResumeTokenController + TelemetryCtrl TelemetryController +} + +type CloserWaiter interface { + Close(context.Context) error + Wait() <-chan error +} + +// CoordinatorClient is an abstraction of the Coordinator's control protocol interface from the +// perspective of a protocol client (i.e. the Coder Agent is also a client of this interface). +type CoordinatorClient interface { + Close() error + Send(*proto.CoordinateRequest) error + Recv() (*proto.CoordinateResponse, error) +} + +// A CoordinationController accepts connections to the control plane, and handles the Coordination +// protocol on behalf of some Coordinatee (tailnet.Conn in production). This is the "glue" code +// between them. 
+type CoordinationController interface { + New(CoordinatorClient) CloserWaiter +} + +// DERPClient is an abstraction of the stream of DERPMap updates from the control plane. +type DERPClient interface { + Close() error + Recv() (*tailcfg.DERPMap, error) +} + +// A DERPController accepts connections to the control plane, and handles the DERPMap updates +// delivered over them by programming the data plane (tailnet.Conn or some test interface). +type DERPController interface { + New(DERPClient) CloserWaiter +} + +type ResumeTokenClient interface { + RefreshResumeToken(ctx context.Context, in *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) +} + +type ResumeTokenController interface { + New(ResumeTokenClient) CloserWaiter + Token() (string, bool) +} + +type TelemetryClient interface { + PostTelemetry(ctx context.Context, in *proto.TelemetryRequest) (*proto.TelemetryResponse, error) +} + +type TelemetryController interface { + New(TelemetryClient) +} + +// ControlProtocolClients represents an abstract interface to the tailnet control plane via a set +// of protocol clients. The Closer should close all the clients (e.g. by closing the underlying +// connection). +type ControlProtocolClients struct { + Closer io.Closer + Coordinator CoordinatorClient + DERP DERPClient + ResumeToken ResumeTokenClient + Telemetry TelemetryClient +} + +type ControlProtocolDialer interface { + // Dial connects to the tailnet control plane and returns clients for the different control + // sub-protocols (coordination, DERP maps, resume tokens, and telemetry). If the + // ResumeTokenController is not nil, the dialer should query for a resume token and use it to + // dial, if available. + Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error) +} + +// basicCoordinationController handles the basic coordination operations common to all types of +// tailnet consumers: +// +// 1. sending local node updates to the Coordinator +// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn) +// 3. (optionally) sending ReadyToHandshake acknowledgements for peer updates. 
+type basicCoordinationController struct { + logger slog.Logger + coordinatee Coordinatee + sendAcks bool +} + +func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter { + b := &basicCoordination{ + logger: c.logger, + errChan: make(chan error, 1), + coordinatee: c.coordinatee, + client: client, + respLoopDone: make(chan struct{}), + sendAcks: c.sendAcks, + } + + c.coordinatee.SetNodeCallback(func(node *Node) { + pn, err := NodeToProto(node) + if err != nil { + b.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + b.sendErr(err) + return + } + b.Lock() + defer b.Unlock() + if b.closed { + b.logger.Debug(context.Background(), "ignored node update because coordination is closed") + return + } + err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) + if err != nil { + b.sendErr(xerrors.Errorf("write: %w", err)) + } + }) + go b.respLoop() + + return b +} + +type basicCoordination struct { + sync.Mutex + closed bool + errChan chan error + coordinatee Coordinatee + logger slog.Logger + client CoordinatorClient + respLoopDone chan struct{} + sendAcks bool +} + +func (c *basicCoordination) Close(ctx context.Context) (retErr error) { + c.Lock() + defer c.Unlock() + if c.closed { + return nil + } + c.closed = true + defer func() { + // We shouldn't just close the protocol right away, because the way dRPC streams work is + // that if you close them, that could take effect immediately, even before the Disconnect + // message is processed. Coordinators are supposed to hang up on us once they get a + // Disconnect message, so we should wait around for that until the context expires. + select { + case <-c.respLoopDone: + c.logger.Debug(ctx, "responses closed after disconnect") + return + case <-ctx.Done(): + c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close") + } + // forcefully close the stream + protoErr := c.client.Close() + <-c.respLoopDone + if retErr == nil { + retErr = protoErr + } + }() + err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) + if err != nil && !xerrors.Is(err, io.EOF) { + // Coordinator RPC hangs up when it gets disconnect, so EOF is expected. + return xerrors.Errorf("send disconnect: %w", err) + } + c.logger.Debug(context.Background(), "sent disconnect") + return nil +} + +func (c *basicCoordination) Wait() <-chan error { + return c.errChan +} + +func (c *basicCoordination) sendErr(err error) { + select { + case c.errChan <- err: + default: + } +} + +func (c *basicCoordination) respLoop() { + defer func() { + cErr := c.client.Close() + if cErr != nil { + c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr)) + } + c.coordinatee.SetAllPeersLost() + close(c.respLoopDone) + }() + for { + resp, err := c.client.Recv() + if err != nil { + c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err)) + c.sendErr(xerrors.Errorf("read: %w", err)) + return + } + + err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) + if err != nil { + c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err)) + c.sendErr(xerrors.Errorf("update peers: %w", err)) + return + } + + // Only send ReadyForHandshake acks from peers without a target. + if c.sendAcks { + // Send an ack back for all received peers. 
This could + // potentially be smarter to only send an ACK once per client, + // but there's nothing currently stopping clients from reusing + // IDs. + rfh := []*proto.CoordinateRequest_ReadyForHandshake{} + for _, peer := range resp.GetPeerUpdates() { + if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE { + continue + } + + rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id}) + } + if len(rfh) > 0 { + err := c.client.Send(&proto.CoordinateRequest{ + ReadyForHandshake: rfh, + }) + if err != nil { + c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err)) + c.sendErr(xerrors.Errorf("send: %w", err)) + return + } + } + } + } +} + +type singleDestController struct { + *basicCoordinationController + dest uuid.UUID +} + +// NewSingleDestController creates a CoordinationController for Coder clients that connect to a +// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent. +func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController { + coordinatee.SetTunnelDestination(dest) + return &singleDestController{ + basicCoordinationController: &basicCoordinationController{ + logger: logger, + coordinatee: coordinatee, + sendAcks: false, + }, + dest: dest, + } +} + +func (c *singleDestController) New(client CoordinatorClient) CloserWaiter { + // nolint: forcetypeassert + b := c.basicCoordinationController.New(client).(*basicCoordination) + err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}}) + if err != nil { + b.sendErr(err) + } + return b +} + +// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never +// create tunnels and always send ReadyToHandshake acknowledgements. +func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController { + return &basicCoordinationController{ + logger: logger, + coordinatee: coordinatee, + sendAcks: true, + } +} + +type inMemoryCoordClient struct { + sync.Mutex + ctx context.Context + cancel context.CancelFunc + closed bool + logger slog.Logger + resps <-chan *proto.CoordinateResponse + reqs chan<- *proto.CoordinateRequest +} + +func (c *inMemoryCoordClient) Close() error { + c.cancel() + c.Lock() + defer c.Unlock() + if c.closed { + return nil + } + c.closed = true + close(c.reqs) + return nil +} + +func (c *inMemoryCoordClient) Send(request *proto.CoordinateRequest) error { + c.Lock() + defer c.Unlock() + if c.closed { + return drpc.ClosedError.New("in-memory coordinator client closed") + } + select { + case c.reqs <- request: + return nil + case <-c.ctx.Done(): + return drpc.ClosedError.New("in-memory coordinator client closed") + } +} + +func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) { + select { + case resp, ok := <-c.resps: + if ok { + return resp, nil + } + // response from Coordinator was closed, so close the send direction as well, so that the + // Coordinator won't be waiting for us while shutting down. + _ = c.Close() + return nil, io.EOF + case <-c.ctx.Done(): + return nil, drpc.ClosedError.New("in-memory coord client closed") + } +} + +// NewInMemoryCoordinatorClient creates a coordination client that uses channels to connect to a +// local Coordinator. (The typical alternative is a DRPC-based client.) 
+func NewInMemoryCoordinatorClient( + logger slog.Logger, + clientID, agentID uuid.UUID, + coordinator Coordinator, +) CoordinatorClient { + logger = logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID)) + auth := ClientCoordinateeAuth{AgentID: agentID} + c := &inMemoryCoordClient{logger: logger} + c.ctx, c.cancel = context.WithCancel(context.Background()) + + // use the background context since we will depend exclusively on closing the req channel to + // tell the coordinator we are done. + c.reqs, c.resps = coordinator.Coordinate(context.Background(), + clientID, fmt.Sprintf("inmemory%s", clientID), + auth, + ) + return c +} diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go new file mode 100644 index 0000000000..2e3098f80f --- /dev/null +++ b/tailnet/controllers_test.go @@ -0,0 +1,283 @@ +package tailnet_test + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" +) + +func TestInMemoryCoordination(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). + Times(1).Return(reqs, resps) + + ctrl := tailnet.NewSingleDestController(logger, fConn, agentID) + uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, mCoord)) + defer uut.Close(ctx) + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + // Recv loop should be terminated by the server hanging up after Disconnect + err := testutil.RequireRecvCtx(ctx, t, uut.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestSingleDestController(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). 
+ Times(1).Return(reqs, resps) + + var coord tailnet.Coordinator = mCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger.Named("svc"), + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Hour, + DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") }, + ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), + }) + require.NoError(t, err) + sC, cC := net.Pipe() + + serveErr := make(chan error, 1) + go func() { + err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{ + Name: "client", + ID: clientID, + Auth: tailnet.ClientCoordinateeAuth{ + AgentID: agentID, + }, + }) + serveErr <- err + }() + + client, err := tailnet.NewDRPCClient(cC, logger) + require.NoError(t, err) + protocol, err := client.Coordinate(ctx) + require.NoError(t, err) + + ctrl := tailnet.NewSingleDestController(logger.Named("coordination"), fConn, agentID) + uut := ctrl.New(protocol) + defer uut.Close(ctx) + + coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) + + // Recv loop should be terminated by the server hanging up after Disconnect + err = testutil.RequireRecvCtx(ctx, t, uut.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + clientID := uuid.UUID{1} + agentID := uuid.UUID{2} + mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) + fConn := &fakeCoordinatee{} + + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). 
+ Times(1).Return(reqs, resps) + + var coord tailnet.Coordinator = mCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ + Logger: logger.Named("svc"), + CoordPtr: &coordPtr, + DERPMapUpdateFrequency: time.Hour, + DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, + NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") }, + ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), + }) + require.NoError(t, err) + sC, cC := net.Pipe() + + serveErr := make(chan error, 1) + go func() { + err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{ + Name: "client", + ID: clientID, + Auth: tailnet.ClientCoordinateeAuth{ + AgentID: agentID, + }, + }) + serveErr <- err + }() + + client, err := tailnet.NewDRPCClient(cC, logger) + require.NoError(t, err) + protocol, err := client.Coordinate(ctx) + require.NoError(t, err) + + ctrl := tailnet.NewAgentCoordinationController(logger.Named("coordination"), fConn) + uut := ctrl.New(protocol) + defer uut.Close(ctx) + + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{ + PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{ + Id: clientID[:], + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{ + Id: 3, + Key: nk, + Disco: string(dk), + }, + }}, + }) + + rfh := testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, rfh.ReadyForHandshake) + require.Len(t, rfh.ReadyForHandshake, 1) + require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id) + + go uut.Close(ctx) + dis := testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, dis) + require.NotNil(t, dis.Disconnect) + close(resps) + + // Recv loop should be terminated by the server hanging up after Disconnect + err = testutil.RequireRecvCtx(ctx, t, uut.Wait()) + require.ErrorIs(t, err, io.EOF) +} + +// coordinationTest tests that a coordination behaves correctly +func coordinationTest( + ctx context.Context, t *testing.T, + uut tailnet.CloserWaiter, fConn *fakeCoordinatee, + reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse, + agentID uuid.UUID, +) { + // It should add the tunnel, since we configured as a client + req := testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) + + // when we call the callback, it should send a node update + require.NotNil(t, fConn.callback) + fConn.callback(&tailnet.Node{PreferredDERP: 1}) + + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp()) + + // When we send a peer update, it should update the coordinatee + nk, err := key.NewNode().Public().MarshalBinary() + require.NoError(t, err) + dk, err := key.NewDisco().Public().MarshalText() + require.NoError(t, err) + updates := []*proto.CoordinateResponse_PeerUpdate{ + { + Id: agentID[:], + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{ + Id: 2, + Key: nk, + Disco: string(dk), + }, + }, + } + testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates}) + require.Eventually(t, func() bool { + fConn.Lock() + defer fConn.Unlock() + return len(fConn.updates) > 0 + }, testutil.WaitShort, testutil.IntervalFast) + require.Len(t, fConn.updates[0], 1) + require.Equal(t, agentID[:], 
fConn.updates[0][0].Id) + + errCh := make(chan error, 1) + go func() { + errCh <- uut.Close(ctx) + }() + + // When we close, it should gracefully disconnect + req = testutil.RequireRecvCtx(ctx, t, reqs) + require.NotNil(t, req.Disconnect) + close(resps) + + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.NoError(t, err) + + // It should set all peers lost on the coordinatee + require.Equal(t, 1, fConn.setAllPeersLostCalls) +} + +type fakeCoordinatee struct { + sync.Mutex + callback func(*tailnet.Node) + updates [][]*proto.CoordinateResponse_PeerUpdate + setAllPeersLostCalls int + tunnelDestinations map[uuid.UUID]struct{} +} + +func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error { + f.Lock() + defer f.Unlock() + f.updates = append(f.updates, updates) + return nil +} + +func (f *fakeCoordinatee) SetAllPeersLost() { + f.Lock() + defer f.Unlock() + f.setAllPeersLostCalls++ +} + +func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) { + f.Lock() + defer f.Unlock() + + if f.tunnelDestinations == nil { + f.tunnelDestinations = map[uuid.UUID]struct{}{} + } + f.tunnelDestinations[id] = struct{}{} +} + +func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) { + f.Lock() + defer f.Unlock() + f.callback = callback +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index b059259895..d883ca1b4c 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -90,302 +90,6 @@ type Coordinatee interface { SetTunnelDestination(id uuid.UUID) } -type Coordination interface { - Close(context.Context) error - Error() <-chan error -} - -type remoteCoordination struct { - sync.Mutex - closed bool - errChan chan error - coordinatee Coordinatee - tgt uuid.UUID - logger slog.Logger - protocol proto.DRPCTailnet_CoordinateClient - respLoopDone chan struct{} -} - -// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and -// waiting for the server to hang up the coordination. If the provided context expires, we stop -// waiting for the server and close the coordination stream from our end. -func (c *remoteCoordination) Close(ctx context.Context) (retErr error) { - c.Lock() - defer c.Unlock() - if c.closed { - return nil - } - c.closed = true - defer func() { - // We shouldn't just close the protocol right away, because the way dRPC streams work is - // that if you close them, that could take effect immediately, even before the Disconnect - // message is processed. Coordinators are supposed to hang up on us once they get a - // Disconnect message, so we should wait around for that until the context expires. - select { - case <-c.respLoopDone: - c.logger.Debug(ctx, "responses closed after disconnect") - return - case <-ctx.Done(): - c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close") - } - // forcefully close the stream - protoErr := c.protocol.Close() - <-c.respLoopDone - if retErr == nil { - retErr = protoErr - } - }() - err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}) - if err != nil && !xerrors.Is(err, io.EOF) { - // Coordinator RPC hangs up when it gets disconnect, so EOF is expected. 
- return xerrors.Errorf("send disconnect: %w", err) - } - c.logger.Debug(context.Background(), "sent disconnect") - return nil -} - -func (c *remoteCoordination) Error() <-chan error { - return c.errChan -} - -func (c *remoteCoordination) sendErr(err error) { - select { - case c.errChan <- err: - default: - } -} - -func (c *remoteCoordination) respLoop() { - defer func() { - c.coordinatee.SetAllPeersLost() - close(c.respLoopDone) - }() - for { - resp, err := c.protocol.Recv() - if err != nil { - c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err)) - c.sendErr(xerrors.Errorf("read: %w", err)) - return - } - - err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) - if err != nil { - c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err)) - c.sendErr(xerrors.Errorf("update peers: %w", err)) - return - } - - // Only send acks from peers without a target. - if c.tgt == uuid.Nil { - // Send an ack back for all received peers. This could - // potentially be smarter to only send an ACK once per client, - // but there's nothing currently stopping clients from reusing - // IDs. - rfh := []*proto.CoordinateRequest_ReadyForHandshake{} - for _, peer := range resp.GetPeerUpdates() { - if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE { - continue - } - - rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id}) - } - if len(rfh) > 0 { - err := c.protocol.Send(&proto.CoordinateRequest{ - ReadyForHandshake: rfh, - }) - if err != nil { - c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err)) - c.sendErr(xerrors.Errorf("send: %w", err)) - return - } - } - } - } -} - -// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a -// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as -// a client---agents should NOT set this!). 
-func NewRemoteCoordination(logger slog.Logger, - protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee, - tunnelTarget uuid.UUID, -) Coordination { - c := &remoteCoordination{ - errChan: make(chan error, 1), - coordinatee: coordinatee, - tgt: tunnelTarget, - logger: logger, - protocol: protocol, - respLoopDone: make(chan struct{}), - } - if tunnelTarget != uuid.Nil { - c.coordinatee.SetTunnelDestination(tunnelTarget) - c.Lock() - err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}}) - c.Unlock() - if err != nil { - c.sendErr(err) - } - } - - coordinatee.SetNodeCallback(func(node *Node) { - pn, err := NodeToProto(node) - if err != nil { - c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) - c.sendErr(err) - return - } - c.Lock() - defer c.Unlock() - if c.closed { - c.logger.Debug(context.Background(), "ignored node update because coordination is closed") - return - } - err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}) - if err != nil { - c.sendErr(xerrors.Errorf("write: %w", err)) - } - }) - go c.respLoop() - return c -} - -type inMemoryCoordination struct { - sync.Mutex - ctx context.Context - errChan chan error - closed bool - respLoopDone chan struct{} - coordinatee Coordinatee - logger slog.Logger - resps <-chan *proto.CoordinateResponse - reqs chan<- *proto.CoordinateRequest -} - -func (c *inMemoryCoordination) sendErr(err error) { - select { - case c.errChan <- err: - default: - } -} - -func (c *inMemoryCoordination) Error() <-chan error { - return c.errChan -} - -// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing -// or local clients. Set ClientID to uuid.Nil for an agent. -func NewInMemoryCoordination( - ctx context.Context, logger slog.Logger, - clientID, agentID uuid.UUID, - coordinator Coordinator, coordinatee Coordinatee, -) Coordination { - thisID := agentID - logger = logger.With(slog.F("agent_id", agentID)) - var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID} - if clientID != uuid.Nil { - // this is a client connection - auth = ClientCoordinateeAuth{AgentID: agentID} - logger = logger.With(slog.F("client_id", clientID)) - thisID = clientID - } - c := &inMemoryCoordination{ - ctx: ctx, - errChan: make(chan error, 1), - coordinatee: coordinatee, - logger: logger, - respLoopDone: make(chan struct{}), - } - - // use the background context since we will depend exclusively on closing the req channel to - // tell the coordinator we are done. - c.reqs, c.resps = coordinator.Coordinate(context.Background(), - thisID, fmt.Sprintf("inmemory%s", thisID), - auth, - ) - go c.respLoop() - if agentID != uuid.Nil { - select { - case <-ctx.Done(): - c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err())) - return c - case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}: - // OK! 
- } - } - coordinatee.SetNodeCallback(func(n *Node) { - pn, err := NodeToProto(n) - if err != nil { - c.logger.Critical(ctx, "failed to convert node", slog.Error(err)) - c.sendErr(err) - return - } - c.Lock() - defer c.Unlock() - if c.closed { - return - } - select { - case <-ctx.Done(): - c.logger.Info(ctx, "context expired before sending node update") - return - case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}: - c.logger.Debug(ctx, "sent node in-memory to coordinator") - } - }) - return c -} - -func (c *inMemoryCoordination) respLoop() { - defer func() { - c.coordinatee.SetAllPeersLost() - close(c.respLoopDone) - }() - for resp := range c.resps { - c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp)) - err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates()) - if err != nil { - c.sendErr(xerrors.Errorf("failed to update peers: %w", err)) - return - } - } - c.logger.Debug(context.Background(), "in-memory response channel closed") -} - -func (*inMemoryCoordination) AwaitAck() <-chan struct{} { - // This is only used for tests, so just return a closed channel. - ch := make(chan struct{}) - close(ch) - return ch -} - -// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and -// waiting for the server to hang up the coordination. If the provided context expires, we stop -// waiting for the server and close the coordination stream from our end. -func (c *inMemoryCoordination) Close(ctx context.Context) error { - c.Lock() - defer c.Unlock() - c.logger.Debug(context.Background(), "closing in-memory coordination") - if c.closed { - return nil - } - defer close(c.reqs) - c.closed = true - select { - case <-ctx.Done(): - return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err()) - case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}: - c.logger.Debug(context.Background(), "sent graceful disconnect in-memory") - } - - select { - case <-ctx.Done(): - return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err()) - case <-c.respLoopDone: - return nil - } -} - const LoggerName = "coord" var ( diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index b3a803cd6a..67cf476849 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -2,19 +2,11 @@ package tailnet_test import ( "context" - "io" - "net" "net/netip" - "sync" - "sync/atomic" "testing" - "time" "github.com/google/uuid" "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "tailscale.com/tailcfg" - "tailscale.com/types/key" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" @@ -271,265 +263,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) { ma1.RequireEventuallyClosed(ctx) } -func TestInMemoryCoordination(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - clientID := uuid.UUID{1} - agentID := uuid.UUID{2} - mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) - fConn := &fakeCoordinatee{} - - reqs := make(chan *proto.CoordinateRequest, 100) - resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). 
- Times(1).Return(reqs, resps) - - uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn) - defer uut.Close(ctx) - - coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) - - select { - case err := <-uut.Error(): - require.NoError(t, err) - default: - // OK! - } -} - -func TestRemoteCoordination(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - clientID := uuid.UUID{1} - agentID := uuid.UUID{2} - mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) - fConn := &fakeCoordinatee{} - - reqs := make(chan *proto.CoordinateRequest, 100) - resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). - Times(1).Return(reqs, resps) - - var coord tailnet.Coordinator = mCoord - coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger.Named("svc"), - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Hour, - DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") }, - ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), - }) - require.NoError(t, err) - sC, cC := net.Pipe() - - serveErr := make(chan error, 1) - go func() { - err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{ - Name: "client", - ID: clientID, - Auth: tailnet.ClientCoordinateeAuth{ - AgentID: agentID, - }, - }) - serveErr <- err - }() - - client, err := tailnet.NewDRPCClient(cC, logger) - require.NoError(t, err) - protocol, err := client.Coordinate(ctx) - require.NoError(t, err) - - uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID) - defer uut.Close(ctx) - - coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID) - - // Recv loop should be terminated by the server hanging up after Disconnect - err = testutil.RequireRecvCtx(ctx, t, uut.Error()) - require.ErrorIs(t, err, io.EOF) -} - -func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - clientID := uuid.UUID{1} - agentID := uuid.UUID{2} - mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t)) - fConn := &fakeCoordinatee{} - - reqs := make(chan *proto.CoordinateRequest, 100) - resps := make(chan *proto.CoordinateResponse, 100) - mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}). 
- Times(1).Return(reqs, resps) - - var coord tailnet.Coordinator = mCoord - coordPtr := atomic.Pointer[tailnet.Coordinator]{} - coordPtr.Store(&coord) - svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ - Logger: logger.Named("svc"), - CoordPtr: &coordPtr, - DERPMapUpdateFrequency: time.Hour, - DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") }, - NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") }, - ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), - }) - require.NoError(t, err) - sC, cC := net.Pipe() - - serveErr := make(chan error, 1) - go func() { - err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{ - Name: "client", - ID: clientID, - Auth: tailnet.ClientCoordinateeAuth{ - AgentID: agentID, - }, - }) - serveErr <- err - }() - - client, err := tailnet.NewDRPCClient(cC, logger) - require.NoError(t, err) - protocol, err := client.Coordinate(ctx) - require.NoError(t, err) - - uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{}) - defer uut.Close(ctx) - - nk, err := key.NewNode().Public().MarshalBinary() - require.NoError(t, err) - dk, err := key.NewDisco().Public().MarshalText() - require.NoError(t, err) - testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{ - PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{ - Id: clientID[:], - Kind: proto.CoordinateResponse_PeerUpdate_NODE, - Node: &proto.Node{ - Id: 3, - Key: nk, - Disco: string(dk), - }, - }}, - }) - - rfh := testutil.RequireRecvCtx(ctx, t, reqs) - require.NotNil(t, rfh.ReadyForHandshake) - require.Len(t, rfh.ReadyForHandshake, 1) - require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id) - - go uut.Close(ctx) - dis := testutil.RequireRecvCtx(ctx, t, reqs) - require.NotNil(t, dis) - require.NotNil(t, dis.Disconnect) - close(resps) - - // Recv loop should be terminated by the server hanging up after Disconnect - err = testutil.RequireRecvCtx(ctx, t, uut.Error()) - require.ErrorIs(t, err, io.EOF) -} - -// coordinationTest tests that a coordination behaves correctly -func coordinationTest( - ctx context.Context, t *testing.T, - uut tailnet.Coordination, fConn *fakeCoordinatee, - reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse, - agentID uuid.UUID, -) { - // It should add the tunnel, since we configured as a client - req := testutil.RequireRecvCtx(ctx, t, reqs) - require.Equal(t, agentID[:], req.GetAddTunnel().GetId()) - - // when we call the callback, it should send a node update - require.NotNil(t, fConn.callback) - fConn.callback(&tailnet.Node{PreferredDERP: 1}) - - req = testutil.RequireRecvCtx(ctx, t, reqs) - require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp()) - - // When we send a peer update, it should update the coordinatee - nk, err := key.NewNode().Public().MarshalBinary() - require.NoError(t, err) - dk, err := key.NewDisco().Public().MarshalText() - require.NoError(t, err) - updates := []*proto.CoordinateResponse_PeerUpdate{ - { - Id: agentID[:], - Kind: proto.CoordinateResponse_PeerUpdate_NODE, - Node: &proto.Node{ - Id: 2, - Key: nk, - Disco: string(dk), - }, - }, - } - testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates}) - require.Eventually(t, func() bool { - fConn.Lock() - defer fConn.Unlock() - return len(fConn.updates) > 0 - }, testutil.WaitShort, testutil.IntervalFast) - require.Len(t, fConn.updates[0], 1) - require.Equal(t, agentID[:], fConn.updates[0][0].Id) 
-
-	errCh := make(chan error, 1)
-	go func() {
-		errCh <- uut.Close(ctx)
-	}()
-
-	// When we close, it should gracefully disconnect
-	req = testutil.RequireRecvCtx(ctx, t, reqs)
-	require.NotNil(t, req.Disconnect)
-	close(resps)
-
-	err = testutil.RequireRecvCtx(ctx, t, errCh)
-	require.NoError(t, err)
-
-	// It should set all peers lost on the coordinatee
-	require.Equal(t, 1, fConn.setAllPeersLostCalls)
-}
-
-type fakeCoordinatee struct {
-	sync.Mutex
-	callback             func(*tailnet.Node)
-	updates              [][]*proto.CoordinateResponse_PeerUpdate
-	setAllPeersLostCalls int
-	tunnelDestinations   map[uuid.UUID]struct{}
-}
-
-func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
-	f.Lock()
-	defer f.Unlock()
-	f.updates = append(f.updates, updates)
-	return nil
-}
-
-func (f *fakeCoordinatee) SetAllPeersLost() {
-	f.Lock()
-	defer f.Unlock()
-	f.setAllPeersLostCalls++
-}
-
-func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
-	f.Lock()
-	defer f.Unlock()
-
-	if f.tunnelDestinations == nil {
-		f.tunnelDestinations = map[uuid.UUID]struct{}{}
-	}
-	f.tunnelDestinations[id] = struct{}{}
-}
-
-func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
-	f.Lock()
-	defer f.Unlock()
-	f.callback = callback
-}
-
 // TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
 // is propogated through to the `Authorize` method of the coordinatee auth
 func TestCoordinatorPropogatedPeerContext(t *testing.T) {
diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go
index ff38aec98b..232e7ab027 100644
--- a/tailnet/test/integration/integration.go
+++ b/tailnet/test/integration/integration.go
@@ -467,7 +467,8 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
 		_ = conn.Close()
 	})
 
-	coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
+	ctrl := tailnet.NewSingleDestController(logger, conn, peer.ID)
+	coordination := ctrl.New(coord)
 	t.Cleanup(func() {
 		cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
 		defer cancel()