diff --git a/agent/agent.go b/agent/agent.go
index cb0037dd0e..4c8497d105 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -1352,7 +1352,8 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 	defer close(disconnected)
 	a.closeMutex.Unlock()
 
-	coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
+	ctrl := tailnet.NewAgentCoordinationController(a.logger, network)
+	coordination := ctrl.New(coordinate)
 
 	errCh := make(chan error, 1)
 	go func() {
@@ -1364,7 +1365,7 @@
 				a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
 			}
 			return
-		case err := <-coordination.Error():
+		case err := <-coordination.Wait():
 			errCh <- err
 		}
 	}()
diff --git a/agent/agent_test.go b/agent/agent_test.go
index addae8c3d8..e7fd753b8d 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -1918,10 +1918,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
 			testCtx, testCtxCancel := context.WithCancel(context.Background())
 			t.Cleanup(testCtxCancel)
 			clientID := uuid.New()
-			coordination := tailnet.NewInMemoryCoordination(
-				testCtx, logger,
-				clientID, agentID,
-				coordinator, conn)
+			ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
+			coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, coordinator))
 			t.Cleanup(func() {
 				t.Logf("closing coordination %s", name)
 				cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
@@ -2409,10 +2407,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	testCtx, testCtxCancel := context.WithCancel(context.Background())
 	t.Cleanup(testCtxCancel)
 	clientID := uuid.New()
-	coordination := tailnet.NewInMemoryCoordination(
-		testCtx, logger,
-		clientID, metadata.AgentID,
-		coordinator, conn)
+	ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID)
+	coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
+		logger, clientID, metadata.AgentID, coordinator))
 	t.Cleanup(func() {
 		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
 		defer ccancel()
diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go
index a17f9200a9..8817b311fc 100644
--- a/agent/agenttest/client.go
+++ b/agent/agenttest/client.go
@@ -71,7 +71,6 @@ func NewClient(t testing.TB,
 		t:              t,
 		logger:         logger.Named("client"),
 		agentID:        agentID,
-		coordinator:    coordinator,
 		server:         server,
 		fakeAgentAPI:   fakeAAPI,
 		derpMapUpdates: derpMapUpdates,
@@ -82,7 +81,6 @@ type Client struct {
 	t              testing.TB
 	logger         slog.Logger
 	agentID        uuid.UUID
-	coordinator    tailnet.Coordinator
 	server         *drpcserver.Server
 	fakeAgentAPI   *FakeAgentAPI
 	LastWorkspaceAgent func()
diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go
index 780478e91a..eb14b34519 100644
--- a/codersdk/workspacesdk/connector.go
+++ b/codersdk/workspacesdk/connector.go
@@ -66,6 +66,7 @@ type tailnetAPIConnector struct {
 	clock       quartz.Clock
 	dialOptions *websocket.DialOptions
 	conn        tailnetConn
+	coordCtrl   tailnet.CoordinationController
 	customDialFn func() (proto.DRPCTailnetClient, error)
 
 	clientMu sync.RWMutex
@@ -112,6 +113,7 @@ func (tac *tailnetAPIConnector) manageGracefulTimeout() {
 // Runs a tailnetAPIConnector using the provided connection
 func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
 	tac.conn = conn
+	tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID)
 	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
 	go tac.manageGracefulTimeout()
 	go func() {
@@ -272,7 +274,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
 			tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
 		}
 	}()
-	coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
+	coordination := tac.coordCtrl.New(coord)
 	tac.logger.Debug(tac.ctx, "serving coordinator")
 	select {
 	case <-tac.ctx.Done():
@@ -281,7 +283,7 @@
 		if crdErr != nil {
 			tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
 		}
-	case err = <-coordination.Error():
+	case err = <-coordination.Wait():
 		if err != nil &&
 			!xerrors.Is(err, io.EOF) &&
 			!xerrors.Is(err, context.Canceled) &&
diff --git a/tailnet/controllers.go b/tailnet/controllers.go
new file mode 100644
index 0000000000..84a9bd4d79
--- /dev/null
+++ b/tailnet/controllers.go
@@ -0,0 +1,361 @@
+package tailnet
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"sync"
+
+	"github.com/google/uuid"
+	"golang.org/x/xerrors"
+	"storj.io/drpc"
+	"tailscale.com/tailcfg"
+
+	"cdr.dev/slog"
+	"github.com/coder/coder/v2/tailnet/proto"
+)
+
+// A Controller connects to the tailnet control plane, and then uses the control protocols to
+// program a tailnet.Conn in production (in test it could be an interface simulating the Conn). It
+// delegates this task to sub-controllers responsible for the main areas of the tailnet control
+// protocol: coordination, DERP map updates, resume tokens, and telemetry.
+type Controller struct {
+	Dialer          ControlProtocolDialer
+	CoordCtrl       CoordinationController
+	DERPCtrl        DERPController
+	ResumeTokenCtrl ResumeTokenController
+	TelemetryCtrl   TelemetryController
+}
+
+type CloserWaiter interface {
+	Close(context.Context) error
+	Wait() <-chan error
+}
+
+// CoordinatorClient is an abstraction of the Coordinator's control protocol interface from the
+// perspective of a protocol client (i.e. the Coder Agent is also a client of this interface).
+type CoordinatorClient interface {
+	Close() error
+	Send(*proto.CoordinateRequest) error
+	Recv() (*proto.CoordinateResponse, error)
+}
+
+// A CoordinationController accepts connections to the control plane, and handles the Coordination
+// protocol on behalf of some Coordinatee (tailnet.Conn in production). This is the "glue" code
+// between them.
+type CoordinationController interface {
+	New(CoordinatorClient) CloserWaiter
+}
+
+// DERPClient is an abstraction of the stream of DERPMap updates from the control plane.
+type DERPClient interface {
+	Close() error
+	Recv() (*tailcfg.DERPMap, error)
+}
+
+// A DERPController accepts connections to the control plane, and handles the DERPMap updates
+// delivered over them by programming the data plane (tailnet.Conn or some test interface).
+type DERPController interface {
+	New(DERPClient) CloserWaiter
+}
+
+type ResumeTokenClient interface {
+	RefreshResumeToken(ctx context.Context, in *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
+}
+
+type ResumeTokenController interface {
+	New(ResumeTokenClient) CloserWaiter
+	Token() (string, bool)
+}
+
+type TelemetryClient interface {
+	PostTelemetry(ctx context.Context, in *proto.TelemetryRequest) (*proto.TelemetryResponse, error)
+}
+
+type TelemetryController interface {
+	New(TelemetryClient)
+}
+
+// ControlProtocolClients represents an abstract interface to the tailnet control plane via a set
+// of protocol clients. The Closer should close all the clients (e.g. by closing the underlying
+// connection).
+type ControlProtocolClients struct {
+	Closer      io.Closer
+	Coordinator CoordinatorClient
+	DERP        DERPClient
+	ResumeToken ResumeTokenClient
+	Telemetry   TelemetryClient
+}
+
+type ControlProtocolDialer interface {
+	// Dial connects to the tailnet control plane and returns clients for the different control
+	// sub-protocols (coordination, DERP maps, resume tokens, and telemetry). If the
+	// ResumeTokenController is not nil, the dialer should query for a resume token and use it to
+	// dial, if available.
+	Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error)
+}
+
+// basicCoordinationController handles the basic coordination operations common to all types of
+// tailnet consumers:
+//
+// 1. sending local node updates to the Coordinator
+// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn)
+// 3. (optionally) sending ReadyToHandshake acknowledgements for peer updates.
+type basicCoordinationController struct {
+	logger      slog.Logger
+	coordinatee Coordinatee
+	sendAcks    bool
+}
+
+func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter {
+	b := &basicCoordination{
+		logger:       c.logger,
+		errChan:      make(chan error, 1),
+		coordinatee:  c.coordinatee,
+		client:       client,
+		respLoopDone: make(chan struct{}),
+		sendAcks:     c.sendAcks,
+	}
+
+	c.coordinatee.SetNodeCallback(func(node *Node) {
+		pn, err := NodeToProto(node)
+		if err != nil {
+			b.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
+			b.sendErr(err)
+			return
+		}
+		b.Lock()
+		defer b.Unlock()
+		if b.closed {
+			b.logger.Debug(context.Background(), "ignored node update because coordination is closed")
+			return
+		}
+		err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
+		if err != nil {
+			b.sendErr(xerrors.Errorf("write: %w", err))
+		}
+	})
+	go b.respLoop()
+
+	return b
+}
+
+type basicCoordination struct {
+	sync.Mutex
+	closed       bool
+	errChan      chan error
+	coordinatee  Coordinatee
+	logger       slog.Logger
+	client       CoordinatorClient
+	respLoopDone chan struct{}
+	sendAcks     bool
+}
+
+func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
+	c.Lock()
+	defer c.Unlock()
+	if c.closed {
+		return nil
+	}
+	c.closed = true
+	defer func() {
+		// We shouldn't just close the protocol right away, because the way dRPC streams work is
+		// that if you close them, that could take effect immediately, even before the Disconnect
+		// message is processed. Coordinators are supposed to hang up on us once they get a
+		// Disconnect message, so we should wait around for that until the context expires.
+		select {
+		case <-c.respLoopDone:
+			c.logger.Debug(ctx, "responses closed after disconnect")
+			return
+		case <-ctx.Done():
+			c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
+		}
+		// forcefully close the stream
+		protoErr := c.client.Close()
+		<-c.respLoopDone
+		if retErr == nil {
+			retErr = protoErr
+		}
+	}()
+	err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
+	if err != nil && !xerrors.Is(err, io.EOF) {
+		// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
+		return xerrors.Errorf("send disconnect: %w", err)
+	}
+	c.logger.Debug(context.Background(), "sent disconnect")
+	return nil
+}
+
+func (c *basicCoordination) Wait() <-chan error {
+	return c.errChan
+}
+
+func (c *basicCoordination) sendErr(err error) {
+	select {
+	case c.errChan <- err:
+	default:
+	}
+}
+
+func (c *basicCoordination) respLoop() {
+	defer func() {
+		cErr := c.client.Close()
+		if cErr != nil {
+			c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr))
+		}
+		c.coordinatee.SetAllPeersLost()
+		close(c.respLoopDone)
+	}()
+	for {
+		resp, err := c.client.Recv()
+		if err != nil {
+			c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
+			c.sendErr(xerrors.Errorf("read: %w", err))
+			return
+		}
+
+		err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
+		if err != nil {
+			c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
+			c.sendErr(xerrors.Errorf("update peers: %w", err))
+			return
+		}
+
+		// Only send ReadyForHandshake acks from peers without a target.
+		if c.sendAcks {
+			// Send an ack back for all received peers. This could
+			// potentially be smarter to only send an ACK once per client,
+			// but there's nothing currently stopping clients from reusing
+			// IDs.
+			rfh := []*proto.CoordinateRequest_ReadyForHandshake{}
+			for _, peer := range resp.GetPeerUpdates() {
+				if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
+					continue
+				}
+
+				rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
+			}
+			if len(rfh) > 0 {
+				err := c.client.Send(&proto.CoordinateRequest{
+					ReadyForHandshake: rfh,
+				})
+				if err != nil {
+					c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
+					c.sendErr(xerrors.Errorf("send: %w", err))
+					return
+				}
+			}
+		}
+	}
+}
+
+type singleDestController struct {
+	*basicCoordinationController
+	dest uuid.UUID
+}
+
+// NewSingleDestController creates a CoordinationController for Coder clients that connect to a
+// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent.
+func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController {
+	coordinatee.SetTunnelDestination(dest)
+	return &singleDestController{
+		basicCoordinationController: &basicCoordinationController{
+			logger:      logger,
+			coordinatee: coordinatee,
+			sendAcks:    false,
+		},
+		dest: dest,
+	}
+}
+
+func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
+	// nolint: forcetypeassert
+	b := c.basicCoordinationController.New(client).(*basicCoordination)
+	err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}})
+	if err != nil {
+		b.sendErr(err)
+	}
+	return b
+}
+
+// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
+// create tunnels and always send ReadyToHandshake acknowledgements.
+func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController {
+	return &basicCoordinationController{
+		logger:      logger,
+		coordinatee: coordinatee,
+		sendAcks:    true,
+	}
+}
+
+type inMemoryCoordClient struct {
+	sync.Mutex
+	ctx    context.Context
+	cancel context.CancelFunc
+	closed bool
+	logger slog.Logger
+	resps  <-chan *proto.CoordinateResponse
+	reqs   chan<- *proto.CoordinateRequest
+}
+
+func (c *inMemoryCoordClient) Close() error {
+	c.cancel()
+	c.Lock()
+	defer c.Unlock()
+	if c.closed {
+		return nil
+	}
+	c.closed = true
+	close(c.reqs)
+	return nil
+}
+
+func (c *inMemoryCoordClient) Send(request *proto.CoordinateRequest) error {
+	c.Lock()
+	defer c.Unlock()
+	if c.closed {
+		return drpc.ClosedError.New("in-memory coordinator client closed")
+	}
+	select {
+	case c.reqs <- request:
+		return nil
+	case <-c.ctx.Done():
+		return drpc.ClosedError.New("in-memory coordinator client closed")
+	}
+}
+
+func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) {
+	select {
+	case resp, ok := <-c.resps:
+		if ok {
+			return resp, nil
+		}
+		// response from Coordinator was closed, so close the send direction as well, so that the
+		// Coordinator won't be waiting for us while shutting down.
+		_ = c.Close()
+		return nil, io.EOF
+	case <-c.ctx.Done():
+		return nil, drpc.ClosedError.New("in-memory coord client closed")
+	}
+}
+
+// NewInMemoryCoordinatorClient creates a coordination client that uses channels to connect to a
+// local Coordinator. (The typical alternative is a DRPC-based client.)
+func NewInMemoryCoordinatorClient(
+	logger slog.Logger,
+	clientID, agentID uuid.UUID,
+	coordinator Coordinator,
+) CoordinatorClient {
+	logger = logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID))
+	auth := ClientCoordinateeAuth{AgentID: agentID}
+	c := &inMemoryCoordClient{logger: logger}
+	c.ctx, c.cancel = context.WithCancel(context.Background())
+
+	// use the background context since we will depend exclusively on closing the req channel to
+	// tell the coordinator we are done.
+	c.reqs, c.resps = coordinator.Coordinate(context.Background(),
+		clientID, fmt.Sprintf("inmemory%s", clientID),
+		auth,
+	)
+	return c
+}
diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go
new file mode 100644
index 0000000000..2e3098f80f
--- /dev/null
+++ b/tailnet/controllers_test.go
@@ -0,0 +1,283 @@
+package tailnet_test
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+	"tailscale.com/tailcfg"
+	"tailscale.com/types/key"
+
+	"cdr.dev/slog"
+	"cdr.dev/slog/sloggers/slogtest"
+	"github.com/coder/coder/v2/tailnet"
+	"github.com/coder/coder/v2/tailnet/proto"
+	"github.com/coder/coder/v2/tailnet/tailnettest"
+	"github.com/coder/coder/v2/testutil"
+)
+
+func TestInMemoryCoordination(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	clientID := uuid.UUID{1}
+	agentID := uuid.UUID{2}
+	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
+	fConn := &fakeCoordinatee{}
+
+	reqs := make(chan *proto.CoordinateRequest, 100)
+	resps := make(chan *proto.CoordinateResponse, 100)
+	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
+		Times(1).Return(reqs, resps)
+
+	ctrl := tailnet.NewSingleDestController(logger, fConn, agentID)
+	uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, mCoord))
+	defer uut.Close(ctx)
+
+	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
+
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err := testutil.RequireRecvCtx(ctx, t, uut.Wait())
+	require.ErrorIs(t, err, io.EOF)
+}
+
+func TestSingleDestController(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	clientID := uuid.UUID{1}
+	agentID := uuid.UUID{2}
+	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
+	fConn := &fakeCoordinatee{}
+
+	reqs := make(chan *proto.CoordinateRequest, 100)
+	resps := make(chan *proto.CoordinateResponse, 100)
+	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
+		Times(1).Return(reqs, resps)
+
+	var coord tailnet.Coordinator = mCoord
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)
+	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
+		Logger:                  logger.Named("svc"),
+		CoordPtr:                &coordPtr,
+		DERPMapUpdateFrequency:  time.Hour,
+		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
+		NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
+		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
+	})
+	require.NoError(t, err)
+	sC, cC := net.Pipe()
+
+	serveErr := make(chan error, 1)
+	go func() {
+		err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
+			Name: "client",
+			ID:   clientID,
+			Auth: tailnet.ClientCoordinateeAuth{
+				AgentID: agentID,
+			},
+		})
+		serveErr <- err
+	}()
+
+	client, err := tailnet.NewDRPCClient(cC, logger)
+	require.NoError(t, err)
+	protocol, err := client.Coordinate(ctx)
+	require.NoError(t, err)
+
+	ctrl := tailnet.NewSingleDestController(logger.Named("coordination"), fConn, agentID)
+	uut := ctrl.New(protocol)
+	defer uut.Close(ctx)
+
+	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
+
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
+	require.ErrorIs(t, err, io.EOF)
+}
+
+func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
+	t.Parallel()
+	ctx := testutil.Context(t, testutil.WaitShort)
+	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+	clientID := uuid.UUID{1}
+	agentID := uuid.UUID{2}
+	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
+	fConn := &fakeCoordinatee{}
+
+	reqs := make(chan *proto.CoordinateRequest, 100)
+	resps := make(chan *proto.CoordinateResponse, 100)
+	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
+		Times(1).Return(reqs, resps)
+
+	var coord tailnet.Coordinator = mCoord
+	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
+	coordPtr.Store(&coord)
+	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
+		Logger:                  logger.Named("svc"),
+		CoordPtr:                &coordPtr,
+		DERPMapUpdateFrequency:  time.Hour,
+		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
+		NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
+		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
+	})
+	require.NoError(t, err)
+	sC, cC := net.Pipe()
+
+	serveErr := make(chan error, 1)
+	go func() {
+		err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
+			Name: "client",
+			ID:   clientID,
+			Auth: tailnet.ClientCoordinateeAuth{
+				AgentID: agentID,
+			},
+		})
+		serveErr <- err
+	}()
+
+	client, err := tailnet.NewDRPCClient(cC, logger)
+	require.NoError(t, err)
+	protocol, err := client.Coordinate(ctx)
+	require.NoError(t, err)
+
+	ctrl := tailnet.NewAgentCoordinationController(logger.Named("coordination"), fConn)
+	uut := ctrl.New(protocol)
+	defer uut.Close(ctx)
+
+	nk, err := key.NewNode().Public().MarshalBinary()
+	require.NoError(t, err)
+	dk, err := key.NewDisco().Public().MarshalText()
+	require.NoError(t, err)
+	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
+		PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
+			Id:   clientID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: &proto.Node{
+				Id:    3,
+				Key:   nk,
+				Disco: string(dk),
+			},
+		}},
+	})
+
+	rfh := testutil.RequireRecvCtx(ctx, t, reqs)
+	require.NotNil(t, rfh.ReadyForHandshake)
+	require.Len(t, rfh.ReadyForHandshake, 1)
+	require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
+
+	go uut.Close(ctx)
+	dis := testutil.RequireRecvCtx(ctx, t, reqs)
+	require.NotNil(t, dis)
+	require.NotNil(t, dis.Disconnect)
+	close(resps)
+
+	// Recv loop should be terminated by the server hanging up after Disconnect
+	err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
+	require.ErrorIs(t, err, io.EOF)
+}
+
+// coordinationTest tests that a coordination behaves correctly
+func coordinationTest(
+	ctx context.Context, t *testing.T,
+	uut tailnet.CloserWaiter, fConn *fakeCoordinatee,
+	reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
+	agentID uuid.UUID,
+) {
+	// It should add the tunnel, since we configured as a client
+	req := testutil.RequireRecvCtx(ctx, t, reqs)
+	require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
+
+	// when we call the callback, it should send a node update
+	require.NotNil(t, fConn.callback)
+	fConn.callback(&tailnet.Node{PreferredDERP: 1})
+
+	req = testutil.RequireRecvCtx(ctx, t, reqs)
+	require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
+
+	// When we send a peer update, it should update the coordinatee
+	nk, err := key.NewNode().Public().MarshalBinary()
+	require.NoError(t, err)
+	dk, err := key.NewDisco().Public().MarshalText()
+	require.NoError(t, err)
+	updates := []*proto.CoordinateResponse_PeerUpdate{
+		{
+			Id:   agentID[:],
+			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
+			Node: &proto.Node{
+				Id:    2,
+				Key:   nk,
+				Disco: string(dk),
+			},
+		},
+	}
+	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
+	require.Eventually(t, func() bool {
+		fConn.Lock()
+		defer fConn.Unlock()
+		return len(fConn.updates) > 0
+	}, testutil.WaitShort, testutil.IntervalFast)
+	require.Len(t, fConn.updates[0], 1)
+	require.Equal(t, agentID[:], fConn.updates[0][0].Id)
+
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- uut.Close(ctx)
+	}()
+
+	// When we close, it should gracefully disconnect
+	req = testutil.RequireRecvCtx(ctx, t, reqs)
+	require.NotNil(t, req.Disconnect)
+	close(resps)
+
+	err = testutil.RequireRecvCtx(ctx, t, errCh)
+	require.NoError(t, err)
+
+	// It should set all peers lost on the coordinatee
+	require.Equal(t, 1, fConn.setAllPeersLostCalls)
+}
+
+type fakeCoordinatee struct {
+	sync.Mutex
+	callback             func(*tailnet.Node)
+	updates              [][]*proto.CoordinateResponse_PeerUpdate
+	setAllPeersLostCalls int
+	tunnelDestinations   map[uuid.UUID]struct{}
+}
+
+func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
+	f.Lock()
+	defer f.Unlock()
+	f.updates = append(f.updates, updates)
+	return nil
+}
+
+func (f *fakeCoordinatee) SetAllPeersLost() {
+	f.Lock()
+	defer f.Unlock()
+	f.setAllPeersLostCalls++
+}
+
+func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
+	f.Lock()
+	defer f.Unlock()
+
+	if f.tunnelDestinations == nil {
+		f.tunnelDestinations = map[uuid.UUID]struct{}{}
+	}
+	f.tunnelDestinations[id] = struct{}{}
+}
+
+func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
+	f.Lock()
+	defer f.Unlock()
+	f.callback = callback
+}
diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go
index b059259895..d883ca1b4c 100644
--- a/tailnet/coordinator.go
+++ b/tailnet/coordinator.go
@@ -90,302 +90,6 @@ type Coordinatee interface {
 	SetTunnelDestination(id uuid.UUID)
 }
 
-type Coordination interface {
-	Close(context.Context) error
-	Error() <-chan error
-}
-
-type remoteCoordination struct {
-	sync.Mutex
-	closed       bool
-	errChan      chan error
-	coordinatee  Coordinatee
-	tgt          uuid.UUID
-	logger       slog.Logger
-	protocol     proto.DRPCTailnet_CoordinateClient
-	respLoopDone chan struct{}
-}
-
-// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
-// waiting for the server to hang up the coordination. If the provided context expires, we stop
-// waiting for the server and close the coordination stream from our end.
-func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
-	c.Lock()
-	defer c.Unlock()
-	if c.closed {
-		return nil
-	}
-	c.closed = true
-	defer func() {
-		// We shouldn't just close the protocol right away, because the way dRPC streams work is
-		// that if you close them, that could take effect immediately, even before the Disconnect
-		// message is processed. Coordinators are supposed to hang up on us once they get a
-		// Disconnect message, so we should wait around for that until the context expires.
-		select {
-		case <-c.respLoopDone:
-			c.logger.Debug(ctx, "responses closed after disconnect")
-			return
-		case <-ctx.Done():
-			c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
-		}
-		// forcefully close the stream
-		protoErr := c.protocol.Close()
-		<-c.respLoopDone
-		if retErr == nil {
-			retErr = protoErr
-		}
-	}()
-	err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
-	if err != nil && !xerrors.Is(err, io.EOF) {
-		// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
-		return xerrors.Errorf("send disconnect: %w", err)
-	}
-	c.logger.Debug(context.Background(), "sent disconnect")
-	return nil
-}
-
-func (c *remoteCoordination) Error() <-chan error {
-	return c.errChan
-}
-
-func (c *remoteCoordination) sendErr(err error) {
-	select {
-	case c.errChan <- err:
-	default:
-	}
-}
-
-func (c *remoteCoordination) respLoop() {
-	defer func() {
-		c.coordinatee.SetAllPeersLost()
-		close(c.respLoopDone)
-	}()
-	for {
-		resp, err := c.protocol.Recv()
-		if err != nil {
-			c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
-			c.sendErr(xerrors.Errorf("read: %w", err))
-			return
-		}
-
-		err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
-		if err != nil {
-			c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
-			c.sendErr(xerrors.Errorf("update peers: %w", err))
-			return
-		}
-
-		// Only send acks from peers without a target.
-		if c.tgt == uuid.Nil {
-			// Send an ack back for all received peers. This could
-			// potentially be smarter to only send an ACK once per client,
-			// but there's nothing currently stopping clients from reusing
-			// IDs.
-			rfh := []*proto.CoordinateRequest_ReadyForHandshake{}
-			for _, peer := range resp.GetPeerUpdates() {
-				if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
-					continue
-				}
-
-				rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
-			}
-			if len(rfh) > 0 {
-				err := c.protocol.Send(&proto.CoordinateRequest{
-					ReadyForHandshake: rfh,
-				})
-				if err != nil {
-					c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
-					c.sendErr(xerrors.Errorf("send: %w", err))
-					return
-				}
-			}
-		}
-	}
-}
-
-// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
-// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
-// a client---agents should NOT set this!).
-func NewRemoteCoordination(logger slog.Logger,
-	protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
-	tunnelTarget uuid.UUID,
-) Coordination {
-	c := &remoteCoordination{
-		errChan:      make(chan error, 1),
-		coordinatee:  coordinatee,
-		tgt:          tunnelTarget,
-		logger:       logger,
-		protocol:     protocol,
-		respLoopDone: make(chan struct{}),
-	}
-	if tunnelTarget != uuid.Nil {
-		c.coordinatee.SetTunnelDestination(tunnelTarget)
-		c.Lock()
-		err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
-		c.Unlock()
-		if err != nil {
-			c.sendErr(err)
-		}
-	}
-
-	coordinatee.SetNodeCallback(func(node *Node) {
-		pn, err := NodeToProto(node)
-		if err != nil {
-			c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
-			c.sendErr(err)
-			return
-		}
-		c.Lock()
-		defer c.Unlock()
-		if c.closed {
-			c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
-			return
-		}
-		err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
-		if err != nil {
-			c.sendErr(xerrors.Errorf("write: %w", err))
-		}
-	})
-	go c.respLoop()
-	return c
-}
-
-type inMemoryCoordination struct {
-	sync.Mutex
-	ctx          context.Context
-	errChan      chan error
-	closed       bool
-	respLoopDone chan struct{}
-	coordinatee  Coordinatee
-	logger       slog.Logger
-	resps        <-chan *proto.CoordinateResponse
-	reqs         chan<- *proto.CoordinateRequest
-}
-
-func (c *inMemoryCoordination) sendErr(err error) {
-	select {
-	case c.errChan <- err:
-	default:
-	}
-}
-
-func (c *inMemoryCoordination) Error() <-chan error {
-	return c.errChan
-}
-
-// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing
-// or local clients. Set ClientID to uuid.Nil for an agent.
-func NewInMemoryCoordination(
-	ctx context.Context, logger slog.Logger,
-	clientID, agentID uuid.UUID,
-	coordinator Coordinator, coordinatee Coordinatee,
-) Coordination {
-	thisID := agentID
-	logger = logger.With(slog.F("agent_id", agentID))
-	var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
-	if clientID != uuid.Nil {
-		// this is a client connection
-		auth = ClientCoordinateeAuth{AgentID: agentID}
-		logger = logger.With(slog.F("client_id", clientID))
-		thisID = clientID
-	}
-	c := &inMemoryCoordination{
-		ctx:          ctx,
-		errChan:      make(chan error, 1),
-		coordinatee:  coordinatee,
-		logger:       logger,
-		respLoopDone: make(chan struct{}),
-	}
-
-	// use the background context since we will depend exclusively on closing the req channel to
-	// tell the coordinator we are done.
-	c.reqs, c.resps = coordinator.Coordinate(context.Background(),
-		thisID, fmt.Sprintf("inmemory%s", thisID),
-		auth,
-	)
-	go c.respLoop()
-	if agentID != uuid.Nil {
-		select {
-		case <-ctx.Done():
-			c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
-			return c
-		case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
-			// OK!
-		}
-	}
-	coordinatee.SetNodeCallback(func(n *Node) {
-		pn, err := NodeToProto(n)
-		if err != nil {
-			c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
-			c.sendErr(err)
-			return
-		}
-		c.Lock()
-		defer c.Unlock()
-		if c.closed {
-			return
-		}
-		select {
-		case <-ctx.Done():
-			c.logger.Info(ctx, "context expired before sending node update")
-			return
-		case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
-			c.logger.Debug(ctx, "sent node in-memory to coordinator")
-		}
-	})
-	return c
-}
-
-func (c *inMemoryCoordination) respLoop() {
-	defer func() {
-		c.coordinatee.SetAllPeersLost()
-		close(c.respLoopDone)
-	}()
-	for resp := range c.resps {
-		c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
-		err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
-		if err != nil {
-			c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
-			return
-		}
-	}
-	c.logger.Debug(context.Background(), "in-memory response channel closed")
-}
-
-func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
-	// This is only used for tests, so just return a closed channel.
-	ch := make(chan struct{})
-	close(ch)
-	return ch
-}
-
-// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
-// waiting for the server to hang up the coordination. If the provided context expires, we stop
-// waiting for the server and close the coordination stream from our end.
-func (c *inMemoryCoordination) Close(ctx context.Context) error {
-	c.Lock()
-	defer c.Unlock()
-	c.logger.Debug(context.Background(), "closing in-memory coordination")
-	if c.closed {
-		return nil
-	}
-	defer close(c.reqs)
-	c.closed = true
-	select {
-	case <-ctx.Done():
-		return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
-	case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
-		c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
-	}
-
-	select {
-	case <-ctx.Done():
-		return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
-	case <-c.respLoopDone:
-		return nil
-	}
-}
-
 const LoggerName = "coord"
 
 var (
diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go
index b3a803cd6a..67cf476849 100644
--- a/tailnet/coordinator_test.go
+++ b/tailnet/coordinator_test.go
@@ -2,19 +2,11 @@ package tailnet_test
 
 import (
 	"context"
-	"io"
-	"net"
 	"net/netip"
-	"sync"
-	"sync/atomic"
 	"testing"
-	"time"
 
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
-	"go.uber.org/mock/gomock"
-	"tailscale.com/tailcfg"
-	"tailscale.com/types/key"
 
 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
@@ -271,265 +263,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
 	ma1.RequireEventuallyClosed(ctx)
 }
 
-func TestInMemoryCoordination(t *testing.T) {
-	t.Parallel()
-	ctx := testutil.Context(t, testutil.WaitShort)
-	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	clientID := uuid.UUID{1}
-	agentID := uuid.UUID{2}
-	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
-	fConn := &fakeCoordinatee{}
-
-	reqs := make(chan *proto.CoordinateRequest, 100)
-	resps := make(chan *proto.CoordinateResponse, 100)
-	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
-		Times(1).Return(reqs, resps)
-
-	uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
-	defer uut.Close(ctx)
-
-	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
-
-	select {
-	case err := <-uut.Error():
-		require.NoError(t, err)
-	default:
-		// OK!
-	}
-}
-
-func TestRemoteCoordination(t *testing.T) {
-	t.Parallel()
-	ctx := testutil.Context(t, testutil.WaitShort)
-	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	clientID := uuid.UUID{1}
-	agentID := uuid.UUID{2}
-	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
-	fConn := &fakeCoordinatee{}
-
-	reqs := make(chan *proto.CoordinateRequest, 100)
-	resps := make(chan *proto.CoordinateResponse, 100)
-	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
-		Times(1).Return(reqs, resps)
-
-	var coord tailnet.Coordinator = mCoord
-	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
-	coordPtr.Store(&coord)
-	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
-		Logger:                  logger.Named("svc"),
-		CoordPtr:                &coordPtr,
-		DERPMapUpdateFrequency:  time.Hour,
-		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
-		NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
-		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
-	})
-	require.NoError(t, err)
-	sC, cC := net.Pipe()
-
-	serveErr := make(chan error, 1)
-	go func() {
-		err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
-			Name: "client",
-			ID:   clientID,
-			Auth: tailnet.ClientCoordinateeAuth{
-				AgentID: agentID,
-			},
-		})
-		serveErr <- err
-	}()
-
-	client, err := tailnet.NewDRPCClient(cC, logger)
-	require.NoError(t, err)
-	protocol, err := client.Coordinate(ctx)
-	require.NoError(t, err)
-
-	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
-	defer uut.Close(ctx)
-
-	coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
-
-	// Recv loop should be terminated by the server hanging up after Disconnect
-	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
-	require.ErrorIs(t, err, io.EOF)
-}
-
-func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
-	t.Parallel()
-	ctx := testutil.Context(t, testutil.WaitShort)
-	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
-	clientID := uuid.UUID{1}
-	agentID := uuid.UUID{2}
-	mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
-	fConn := &fakeCoordinatee{}
-
-	reqs := make(chan *proto.CoordinateRequest, 100)
-	resps := make(chan *proto.CoordinateResponse, 100)
-	mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
-		Times(1).Return(reqs, resps)
-
-	var coord tailnet.Coordinator = mCoord
-	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
-	coordPtr.Store(&coord)
-	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
-		Logger:                  logger.Named("svc"),
-		CoordPtr:                &coordPtr,
-		DERPMapUpdateFrequency:  time.Hour,
-		DERPMapFn:               func() *tailcfg.DERPMap { panic("not implemented") },
-		NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
-		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
-	})
-	require.NoError(t, err)
-	sC, cC := net.Pipe()
-
-	serveErr := make(chan error, 1)
-	go func() {
-		err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
-			Name: "client",
-			ID:   clientID,
-			Auth: tailnet.ClientCoordinateeAuth{
-				AgentID: agentID,
-			},
-		})
-		serveErr <- err
-	}()
-
-	client, err := tailnet.NewDRPCClient(cC, logger)
-	require.NoError(t, err)
-	protocol, err := client.Coordinate(ctx)
-	require.NoError(t, err)
-
-	uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
-	defer uut.Close(ctx)
-
-	nk, err := key.NewNode().Public().MarshalBinary()
-	require.NoError(t, err)
-	dk, err := key.NewDisco().Public().MarshalText()
-	require.NoError(t, err)
-	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
-		PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
-			Id:   clientID[:],
-			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
-			Node: &proto.Node{
-				Id:    3,
-				Key:   nk,
-				Disco: string(dk),
-			},
-		}},
-	})
-
-	rfh := testutil.RequireRecvCtx(ctx, t, reqs)
-	require.NotNil(t, rfh.ReadyForHandshake)
-	require.Len(t, rfh.ReadyForHandshake, 1)
-	require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
-
-	go uut.Close(ctx)
-	dis := testutil.RequireRecvCtx(ctx, t, reqs)
-	require.NotNil(t, dis)
-	require.NotNil(t, dis.Disconnect)
-	close(resps)
-
-	// Recv loop should be terminated by the server hanging up after Disconnect
-	err = testutil.RequireRecvCtx(ctx, t, uut.Error())
-	require.ErrorIs(t, err, io.EOF)
-}
-
-// coordinationTest tests that a coordination behaves correctly
-func coordinationTest(
-	ctx context.Context, t *testing.T,
-	uut tailnet.Coordination, fConn *fakeCoordinatee,
-	reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
-	agentID uuid.UUID,
-) {
-	// It should add the tunnel, since we configured as a client
-	req := testutil.RequireRecvCtx(ctx, t, reqs)
-	require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
-
-	// when we call the callback, it should send a node update
-	require.NotNil(t, fConn.callback)
-	fConn.callback(&tailnet.Node{PreferredDERP: 1})
-
-	req = testutil.RequireRecvCtx(ctx, t, reqs)
-	require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
-
-	// When we send a peer update, it should update the coordinatee
-	nk, err := key.NewNode().Public().MarshalBinary()
-	require.NoError(t, err)
-	dk, err := key.NewDisco().Public().MarshalText()
-	require.NoError(t, err)
-	updates := []*proto.CoordinateResponse_PeerUpdate{
-		{
-			Id:   agentID[:],
-			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
-			Node: &proto.Node{
-				Id:    2,
-				Key:   nk,
-				Disco: string(dk),
-			},
-		},
-	}
-	testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
-	require.Eventually(t, func() bool {
-		fConn.Lock()
-		defer fConn.Unlock()
-		return len(fConn.updates) > 0
-	}, testutil.WaitShort, testutil.IntervalFast)
-	require.Len(t, fConn.updates[0], 1)
-	require.Equal(t, agentID[:], fConn.updates[0][0].Id)
-
-	errCh := make(chan error, 1)
-	go func() {
-		errCh <- uut.Close(ctx)
-	}()
-
-	// When we close, it should gracefully disconnect
-	req = testutil.RequireRecvCtx(ctx, t, reqs)
-	require.NotNil(t, req.Disconnect)
-	close(resps)
-
-	err = testutil.RequireRecvCtx(ctx, t, errCh)
-	require.NoError(t, err)
-
-	// It should set all peers lost on the coordinatee
-	require.Equal(t, 1, fConn.setAllPeersLostCalls)
-}
-
-type fakeCoordinatee struct {
-	sync.Mutex
-	callback             func(*tailnet.Node)
-	updates              [][]*proto.CoordinateResponse_PeerUpdate
-	setAllPeersLostCalls int
-	tunnelDestinations   map[uuid.UUID]struct{}
-}
-
-func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
-	f.Lock()
-	defer f.Unlock()
-	f.updates = append(f.updates, updates)
-	return nil
-}
-
-func (f *fakeCoordinatee) SetAllPeersLost() {
-	f.Lock()
-	defer f.Unlock()
-	f.setAllPeersLostCalls++
-}
-
-func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
-	f.Lock()
-	defer f.Unlock()
-
-	if f.tunnelDestinations == nil {
-		f.tunnelDestinations = map[uuid.UUID]struct{}{}
-	}
-	f.tunnelDestinations[id] = struct{}{}
-}
-
-func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
-	f.Lock()
-	defer f.Unlock()
-	f.callback = callback
-}
-
 // TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
 // is propogated through to the `Authorize“ method of the coordinatee auth
 func TestCoordinatorPropogatedPeerContext(t *testing.T) {
diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go
index ff38aec98b..232e7ab027 100644
--- a/tailnet/test/integration/integration.go
+++ b/tailnet/test/integration/integration.go
@@ -467,7 +467,8 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
 		_ = conn.Close()
 	})
 
-	coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
+	ctrl := tailnet.NewSingleDestController(logger, conn, peer.ID)
+	coordination := ctrl.New(coord)
 	t.Cleanup(func() {
 		cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
 		defer cancel()