chore: refactor coordination (#15343)

Refactors the way consumers of the Tailnet API (which include both workspace "agents" and "clients") interact with it. Introduces the idea of abstract "controllers" for each of the RPCs in the API, and implements a Coordination controller by refactoring from `workspacesdk`.

chore re: #14729
Spike Curtis
2024-11-05 13:50:10 +04:00
committed by GitHub
parent 765314ce18
commit 886dcbec84
9 changed files with 658 additions and 578 deletions
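In practice, the change replaces the direct coordination constructors (`NewRemoteCoordination`, `NewInMemoryCoordination`) with a controller that is built once and then mints a `CloserWaiter` per RPC stream. A minimal sketch of the agent-side flow, condensed from the diffs below (names such as `network`, `coordinate`, and `errCh` come from the surrounding agent code):

	// Build the controller once; it knows how to program the tailnet.Conn.
	ctrl := tailnet.NewAgentCoordinationController(logger, network)
	// Each established Coordinate RPC stream yields a new coordination.
	coordination := ctrl.New(coordinate)
	select {
	case <-ctx.Done():
		// Graceful shutdown: send Disconnect, wait for the server to hang up.
		_ = coordination.Close(ctx)
	case err := <-coordination.Wait():
		errCh <- err // the coordination ended on its own
	}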


@@ -1352,7 +1352,8 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 	defer close(disconnected)
 	a.closeMutex.Unlock()
-	coordination := tailnet.NewRemoteCoordination(a.logger, coordinate, network, uuid.Nil)
+	ctrl := tailnet.NewAgentCoordinationController(a.logger, network)
+	coordination := ctrl.New(coordinate)
 	errCh := make(chan error, 1)
 	go func() {
@@ -1364,7 +1365,7 @@ func (a *agent) runCoordinator(ctx context.Context, conn drpc.Conn, network *tai
 				a.logger.Warn(ctx, "failed to close remote coordination", slog.Error(err))
 			}
 			return
-		case err := <-coordination.Error():
+		case err := <-coordination.Wait():
 			errCh <- err
 		}
 	}()


@@ -1918,10 +1918,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
 	testCtx, testCtxCancel := context.WithCancel(context.Background())
 	t.Cleanup(testCtxCancel)
 	clientID := uuid.New()
-	coordination := tailnet.NewInMemoryCoordination(
-		testCtx, logger,
-		clientID, agentID,
-		coordinator, conn)
+	ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
+	coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, coordinator))
 	t.Cleanup(func() {
 		t.Logf("closing coordination %s", name)
 		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
@@ -2409,10 +2407,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
 	testCtx, testCtxCancel := context.WithCancel(context.Background())
 	t.Cleanup(testCtxCancel)
 	clientID := uuid.New()
-	coordination := tailnet.NewInMemoryCoordination(
-		testCtx, logger,
-		clientID, metadata.AgentID,
-		coordinator, conn)
+	ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID)
+	coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
+		logger, clientID, metadata.AgentID, coordinator))
 	t.Cleanup(func() {
 		cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
 		defer ccancel()


@@ -71,7 +71,6 @@ func NewClient(t testing.TB,
 		t:              t,
 		logger:         logger.Named("client"),
 		agentID:        agentID,
-		coordinator:    coordinator,
 		server:         server,
 		fakeAgentAPI:   fakeAAPI,
 		derpMapUpdates: derpMapUpdates,
@@ -82,7 +81,6 @@ type Client struct {
 	t            testing.TB
 	logger       slog.Logger
 	agentID      uuid.UUID
-	coordinator  tailnet.Coordinator
 	server       *drpcserver.Server
 	fakeAgentAPI *FakeAgentAPI
 	LastWorkspaceAgent func()


@@ -66,6 +66,7 @@ type tailnetAPIConnector struct {
 	clock       quartz.Clock
 	dialOptions *websocket.DialOptions
 	conn        tailnetConn
+	coordCtrl   tailnet.CoordinationController
 	customDialFn func() (proto.DRPCTailnetClient, error)

 	clientMu sync.RWMutex
@@ -112,6 +113,7 @@ func (tac *tailnetAPIConnector) manageGracefulTimeout() {
 // Runs a tailnetAPIConnector using the provided connection
 func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
 	tac.conn = conn
+	tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID)
 	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
 	go tac.manageGracefulTimeout()
 	go func() {
@@ -272,7 +274,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
 			tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
 		}
 	}()
-	coordination := tailnet.NewRemoteCoordination(tac.logger, coord, tac.conn, tac.agentID)
+	coordination := tac.coordCtrl.New(coord)
 	tac.logger.Debug(tac.ctx, "serving coordinator")
 	select {
 	case <-tac.ctx.Done():
@@ -281,7 +283,7 @@ func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
 		if crdErr != nil {
 			tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
 		}
-	case err = <-coordination.Error():
+	case err = <-coordination.Wait():
 		if err != nil &&
 			!xerrors.Is(err, io.EOF) &&
 			!xerrors.Is(err, context.Canceled) &&
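Note the lifecycle split this enables: `runConnector` builds the controller once, while `coordinate` calls `ctrl.New(coord)` on every (re)dialed stream, so per-stream state lives in the returned `CloserWaiter`. A sketch of that reconnect loop, using `logger`, `conn`, and `agentID` as in the surrounding code; `dialCoordinate` is a hypothetical helper returning a fresh Coordinate RPC stream:

	// One controller, many coordinations across reconnects.
	ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
	for ctx.Err() == nil {
		coord, err := dialCoordinate(ctx) // hypothetical re-dial helper
		if err != nil {
			return
		}
		coordination := ctrl.New(coord) // fresh coordination per stream
		if err := <-coordination.Wait(); err != nil {
			logger.Debug(ctx, "coordination ended", slog.Error(err))
		}
	}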

tailnet/controllers.go (new file)

@@ -0,0 +1,361 @@
package tailnet

import (
	"context"
	"fmt"
	"io"
	"sync"

	"github.com/google/uuid"
	"golang.org/x/xerrors"
	"storj.io/drpc"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet/proto"
)
// A Controller connects to the tailnet control plane, and then uses the control protocols to
// program a tailnet.Conn in production (in test it could be an interface simulating the Conn). It
// delegates this task to sub-controllers responsible for the main areas of the tailnet control
// protocol: coordination, DERP map updates, resume tokens, and telemetry.
type Controller struct {
Dialer ControlProtocolDialer
CoordCtrl CoordinationController
DERPCtrl DERPController
ResumeTokenCtrl ResumeTokenController
TelemetryCtrl TelemetryController
}
type CloserWaiter interface {
Close(context.Context) error
Wait() <-chan error
}
// CoordinatorClient is an abstraction of the Coordinator's control protocol interface from the
// perspective of a protocol client (i.e. the Coder Agent is also a client of this interface).
type CoordinatorClient interface {
Close() error
Send(*proto.CoordinateRequest) error
Recv() (*proto.CoordinateResponse, error)
}
// A CoordinationController accepts connections to the control plane, and handles the Coordination
// protocol on behalf of some Coordinatee (tailnet.Conn in production). This is the "glue" code
// between them.
type CoordinationController interface {
New(CoordinatorClient) CloserWaiter
}
// DERPClient is an abstraction of the stream of DERPMap updates from the control plane.
type DERPClient interface {
Close() error
Recv() (*tailcfg.DERPMap, error)
}
// A DERPController accepts connections to the control plane, and handles the DERPMap updates
// delivered over them by programming the data plane (tailnet.Conn or some test interface).
type DERPController interface {
New(DERPClient) CloserWaiter
}
type ResumeTokenClient interface {
RefreshResumeToken(ctx context.Context, in *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
}
type ResumeTokenController interface {
New(ResumeTokenClient) CloserWaiter
Token() (string, bool)
}
type TelemetryClient interface {
PostTelemetry(ctx context.Context, in *proto.TelemetryRequest) (*proto.TelemetryResponse, error)
}
type TelemetryController interface {
New(TelemetryClient)
}
// ControlProtocolClients represents an abstract interface to the tailnet control plane via a set
// of protocol clients. The Closer should close all the clients (e.g. by closing the underlying
// connection).
type ControlProtocolClients struct {
Closer io.Closer
Coordinator CoordinatorClient
DERP DERPClient
ResumeToken ResumeTokenClient
Telemetry TelemetryClient
}
type ControlProtocolDialer interface {
// Dial connects to the tailnet control plane and returns clients for the different control
// sub-protocols (coordination, DERP maps, resume tokens, and telemetry). If the
// ResumeTokenController is not nil, the dialer should query for a resume token and use it to
// dial, if available.
Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error)
}
// basicCoordinationController handles the basic coordination operations common to all types of
// tailnet consumers:
//
// 1. sending local node updates to the Coordinator
// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn)
// 3. (optionally) sending ReadyForHandshake acknowledgements for peer updates.
type basicCoordinationController struct {
logger slog.Logger
coordinatee Coordinatee
sendAcks bool
}
func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter {
b := &basicCoordination{
logger: c.logger,
errChan: make(chan error, 1),
coordinatee: c.coordinatee,
client: client,
respLoopDone: make(chan struct{}),
sendAcks: c.sendAcks,
}
c.coordinatee.SetNodeCallback(func(node *Node) {
pn, err := NodeToProto(node)
if err != nil {
b.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
b.sendErr(err)
return
}
b.Lock()
defer b.Unlock()
if b.closed {
b.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = b.client.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
b.sendErr(xerrors.Errorf("write: %w", err))
}
})
go b.respLoop()
return b
}
type basicCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
logger slog.Logger
client CoordinatorClient
respLoopDone chan struct{}
sendAcks bool
}
func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.client.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
}
func (c *basicCoordination) Wait() <-chan error {
return c.errChan
}
func (c *basicCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *basicCoordination) respLoop() {
defer func() {
cErr := c.client.Close()
if cErr != nil {
c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr))
}
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
resp, err := c.client.Recv()
if err != nil {
c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
c.sendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
c.sendErr(xerrors.Errorf("update peers: %w", err))
return
}
// Only send ReadyForHandshake acks if this controller is configured to send them.
if c.sendAcks {
// Send an ack back for all received peers. This could
// potentially be smarter to only send an ACK once per client,
// but there's nothing currently stopping clients from reusing
// IDs.
rfh := []*proto.CoordinateRequest_ReadyForHandshake{}
for _, peer := range resp.GetPeerUpdates() {
if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
continue
}
rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
}
if len(rfh) > 0 {
err := c.client.Send(&proto.CoordinateRequest{
ReadyForHandshake: rfh,
})
if err != nil {
c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
c.sendErr(xerrors.Errorf("send: %w", err))
return
}
}
}
}
}
type singleDestController struct {
*basicCoordinationController
dest uuid.UUID
}
// NewSingleDestController creates a CoordinationController for Coder clients that connect to a
// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent.
func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController {
coordinatee.SetTunnelDestination(dest)
return &singleDestController{
basicCoordinationController: &basicCoordinationController{
logger: logger,
coordinatee: coordinatee,
sendAcks: false,
},
dest: dest,
}
}
func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
// nolint: forcetypeassert
b := c.basicCoordinationController.New(client).(*basicCoordination)
err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}})
if err != nil {
b.sendErr(err)
}
return b
}
// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
// create tunnels and always send ReadyForHandshake acknowledgements.
func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController {
return &basicCoordinationController{
logger: logger,
coordinatee: coordinatee,
sendAcks: true,
}
}
type inMemoryCoordClient struct {
sync.Mutex
ctx context.Context
cancel context.CancelFunc
closed bool
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}
func (c *inMemoryCoordClient) Close() error {
c.cancel()
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.reqs)
return nil
}
func (c *inMemoryCoordClient) Send(request *proto.CoordinateRequest) error {
c.Lock()
defer c.Unlock()
if c.closed {
return drpc.ClosedError.New("in-memory coordinator client closed")
}
select {
case c.reqs <- request:
return nil
case <-c.ctx.Done():
return drpc.ClosedError.New("in-memory coordinator client closed")
}
}
func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) {
select {
case resp, ok := <-c.resps:
if ok {
return resp, nil
}
// response from Coordinator was closed, so close the send direction as well, so that the
// Coordinator won't be waiting for us while shutting down.
_ = c.Close()
return nil, io.EOF
case <-c.ctx.Done():
return nil, drpc.ClosedError.New("in-memory coord client closed")
}
}
// NewInMemoryCoordinatorClient creates a coordination client that uses channels to connect to a
// local Coordinator. (The typical alternative is a DRPC-based client.)
func NewInMemoryCoordinatorClient(
logger slog.Logger,
clientID, agentID uuid.UUID,
coordinator Coordinator,
) CoordinatorClient {
logger = logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID))
auth := ClientCoordinateeAuth{AgentID: agentID}
c := &inMemoryCoordClient{logger: logger}
c.ctx, c.cancel = context.WithCancel(context.Background())
// use the background context since we will depend exclusively on closing the req channel to
// tell the coordinator we are done.
c.reqs, c.resps = coordinator.Coordinate(context.Background(),
clientID, fmt.Sprintf("inmemory%s", clientID),
auth,
)
return c
}
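For a real control-plane connection, the Coordinate DRPC stream satisfies `CoordinatorClient` directly, so the same controller drives both transports. A condensed sketch based on the test file below (assuming an established `net.Conn` named `netConn`, plus `ctx`, `logger`, `conn`, and `agentID` from the caller):

	client, err := tailnet.NewDRPCClient(netConn, logger)
	if err != nil {
		return err
	}
	coord, err := client.Coordinate(ctx)
	if err != nil {
		return err
	}
	// The DRPC Coordinate stream is a CoordinatorClient, so it plugs straight in.
	ctrl := tailnet.NewSingleDestController(logger, conn, agentID)
	coordination := ctrl.New(coord)
	defer coordination.Close(ctx)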

tailnet/controllers_test.go (new file)

@@ -0,0 +1,283 @@
package tailnet_test

import (
	"context"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/coder/v2/tailnet/tailnettest"
	"github.com/coder/coder/v2/testutil"
)
func TestInMemoryCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
ctrl := tailnet.NewSingleDestController(logger, fConn, agentID)
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, agentID, mCoord))
defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Recv loop should be terminated by the server hanging up after Disconnect
err := testutil.RequireRecvCtx(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
func TestSingleDestController(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger.Named("svc"),
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientCoordinateeAuth{
AgentID: agentID,
},
})
serveErr <- err
}()
client, err := tailnet.NewDRPCClient(cC, logger)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
ctrl := tailnet.NewSingleDestController(logger.Named("coordination"), fConn, agentID)
uut := ctrl.New(protocol)
defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger.Named("svc"),
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientCoordinateeAuth{
AgentID: agentID,
},
})
serveErr <- err
}()
client, err := tailnet.NewDRPCClient(cC, logger)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
ctrl := tailnet.NewAgentCoordinationController(logger.Named("coordination"), fConn)
uut := ctrl.New(protocol)
defer uut.Close(ctx)
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: clientID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 3,
Key: nk,
Disco: string(dk),
},
}},
})
rfh := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, rfh.ReadyForHandshake)
require.Len(t, rfh.ReadyForHandshake, 1)
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
go uut.Close(ctx)
dis := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, dis)
require.NotNil(t, dis.Disconnect)
close(resps)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
require.ErrorIs(t, err, io.EOF)
}
// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.CloserWaiter, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
// When we send a peer update, it should update the coordinatee
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
errCh := make(chan error, 1)
go func() {
errCh <- uut.Close(ctx)
}()
// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)
// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
}
type fakeCoordinatee struct {
sync.Mutex
callback func(*tailnet.Node)
updates [][]*proto.CoordinateResponse_PeerUpdate
setAllPeersLostCalls int
tunnelDestinations map[uuid.UUID]struct{}
}
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
f.Lock()
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
}
func (f *fakeCoordinatee) SetAllPeersLost() {
f.Lock()
defer f.Unlock()
f.setAllPeersLostCalls++
}
func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
f.Lock()
defer f.Unlock()
if f.tunnelDestinations == nil {
f.tunnelDestinations = map[uuid.UUID]struct{}{}
}
f.tunnelDestinations[id] = struct{}{}
}
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
f.Lock()
defer f.Unlock()
f.callback = callback
}


@@ -90,302 +90,6 @@ type Coordinatee interface {
 	SetTunnelDestination(id uuid.UUID)
 }
type Coordination interface {
Close(context.Context) error
Error() <-chan error
}
type remoteCoordination struct {
sync.Mutex
closed bool
errChan chan error
coordinatee Coordinatee
tgt uuid.UUID
logger slog.Logger
protocol proto.DRPCTailnet_CoordinateClient
respLoopDone chan struct{}
}
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.protocol.Close()
<-c.respLoopDone
if retErr == nil {
retErr = protoErr
}
}()
err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
}
func (c *remoteCoordination) Error() <-chan error {
return c.errChan
}
func (c *remoteCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *remoteCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for {
resp, err := c.protocol.Recv()
if err != nil {
c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
c.sendErr(xerrors.Errorf("read: %w", err))
return
}
err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
c.sendErr(xerrors.Errorf("update peers: %w", err))
return
}
// Only send acks from peers without a target.
if c.tgt == uuid.Nil {
// Send an ack back for all received peers. This could
// potentially be smarter to only send an ACK once per client,
// but there's nothing currently stopping clients from reusing
// IDs.
rfh := []*proto.CoordinateRequest_ReadyForHandshake{}
for _, peer := range resp.GetPeerUpdates() {
if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
continue
}
rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
}
if len(rfh) > 0 {
err := c.protocol.Send(&proto.CoordinateRequest{
ReadyForHandshake: rfh,
})
if err != nil {
c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
c.sendErr(xerrors.Errorf("send: %w", err))
return
}
}
}
}
}
// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
func NewRemoteCoordination(logger slog.Logger,
protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
tunnelTarget uuid.UUID,
) Coordination {
c := &remoteCoordination{
errChan: make(chan error, 1),
coordinatee: coordinatee,
tgt: tunnelTarget,
logger: logger,
protocol: protocol,
respLoopDone: make(chan struct{}),
}
if tunnelTarget != uuid.Nil {
c.coordinatee.SetTunnelDestination(tunnelTarget)
c.Lock()
err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
c.Unlock()
if err != nil {
c.sendErr(err)
}
}
coordinatee.SetNodeCallback(func(node *Node) {
pn, err := NodeToProto(node)
if err != nil {
c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
return
}
err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
if err != nil {
c.sendErr(xerrors.Errorf("write: %w", err))
}
})
go c.respLoop()
return c
}
type inMemoryCoordination struct {
sync.Mutex
ctx context.Context
errChan chan error
closed bool
respLoopDone chan struct{}
coordinatee Coordinatee
logger slog.Logger
resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest
}
func (c *inMemoryCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
func (c *inMemoryCoordination) Error() <-chan error {
return c.errChan
}
// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing
// or local clients. Set ClientID to uuid.Nil for an agent.
func NewInMemoryCoordination(
ctx context.Context, logger slog.Logger,
clientID, agentID uuid.UUID,
coordinator Coordinator, coordinatee Coordinatee,
) Coordination {
thisID := agentID
logger = logger.With(slog.F("agent_id", agentID))
var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
if clientID != uuid.Nil {
// this is a client connection
auth = ClientCoordinateeAuth{AgentID: agentID}
logger = logger.With(slog.F("client_id", clientID))
thisID = clientID
}
c := &inMemoryCoordination{
ctx: ctx,
errChan: make(chan error, 1),
coordinatee: coordinatee,
logger: logger,
respLoopDone: make(chan struct{}),
}
// use the background context since we will depend exclusively on closing the req channel to
// tell the coordinator we are done.
c.reqs, c.resps = coordinator.Coordinate(context.Background(),
thisID, fmt.Sprintf("inmemory%s", thisID),
auth,
)
go c.respLoop()
if agentID != uuid.Nil {
select {
case <-ctx.Done():
c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
return c
case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
// OK!
}
}
coordinatee.SetNodeCallback(func(n *Node) {
pn, err := NodeToProto(n)
if err != nil {
c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
c.sendErr(err)
return
}
c.Lock()
defer c.Unlock()
if c.closed {
return
}
select {
case <-ctx.Done():
c.logger.Info(ctx, "context expired before sending node update")
return
case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
c.logger.Debug(ctx, "sent node in-memory to coordinator")
}
})
return c
}
func (c *inMemoryCoordination) respLoop() {
defer func() {
c.coordinatee.SetAllPeersLost()
close(c.respLoopDone)
}()
for resp := range c.resps {
c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
return
}
}
c.logger.Debug(context.Background(), "in-memory response channel closed")
}
func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
// This is only used for tests, so just return a closed channel.
ch := make(chan struct{})
close(ch)
return ch
}
// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *inMemoryCoordination) Close(ctx context.Context) error {
c.Lock()
defer c.Unlock()
c.logger.Debug(context.Background(), "closing in-memory coordination")
if c.closed {
return nil
}
defer close(c.reqs)
c.closed = true
select {
case <-ctx.Done():
return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
}
select {
case <-ctx.Done():
return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
case <-c.respLoopDone:
return nil
}
}
const LoggerName = "coord"

var (


@@ -2,19 +2,11 @@ package tailnet_test

 import (
 	"context"
-	"io"
-	"net"
 	"net/netip"
-	"sync"
-	"sync/atomic"
 	"testing"
-	"time"

 	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
-	"go.uber.org/mock/gomock"
-	"tailscale.com/tailcfg"
-	"tailscale.com/types/key"

 	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"
@@ -271,265 +263,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
 	ma1.RequireEventuallyClosed(ctx)
 }
func TestInMemoryCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
uut := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, mCoord, fConn)
defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
select {
case err := <-uut.Error():
require.NoError(t, err)
default:
// OK!
}
}
func TestRemoteCoordination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger.Named("svc"),
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientCoordinateeAuth{
AgentID: agentID,
},
})
serveErr <- err
}()
client, err := tailnet.NewDRPCClient(cC, logger)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, agentID)
defer uut.Close(ctx)
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorIs(t, err, io.EOF)
}
func TestRemoteCoordination_SendsReadyForHandshake(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientID := uuid.UUID{1}
agentID := uuid.UUID{2}
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
fConn := &fakeCoordinatee{}
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
Times(1).Return(reqs, resps)
var coord tailnet.Coordinator = mCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger.Named("svc"),
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
sC, cC := net.Pipe()
serveErr := make(chan error, 1)
go func() {
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
Name: "client",
ID: clientID,
Auth: tailnet.ClientCoordinateeAuth{
AgentID: agentID,
},
})
serveErr <- err
}()
client, err := tailnet.NewDRPCClient(cC, logger)
require.NoError(t, err)
protocol, err := client.Coordinate(ctx)
require.NoError(t, err)
uut := tailnet.NewRemoteCoordination(logger.Named("coordination"), protocol, fConn, uuid.UUID{})
defer uut.Close(ctx)
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
Id: clientID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 3,
Key: nk,
Disco: string(dk),
},
}},
})
rfh := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, rfh.ReadyForHandshake)
require.Len(t, rfh.ReadyForHandshake, 1)
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
go uut.Close(ctx)
dis := testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, dis)
require.NotNil(t, dis.Disconnect)
close(resps)
// Recv loop should be terminated by the server hanging up after Disconnect
err = testutil.RequireRecvCtx(ctx, t, uut.Error())
require.ErrorIs(t, err, io.EOF)
}
// coordinationTest tests that a coordination behaves correctly
func coordinationTest(
ctx context.Context, t *testing.T,
uut tailnet.Coordination, fConn *fakeCoordinatee,
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
agentID uuid.UUID,
) {
// It should add the tunnel, since we configured as a client
req := testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
// when we call the callback, it should send a node update
require.NotNil(t, fConn.callback)
fConn.callback(&tailnet.Node{PreferredDERP: 1})
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
// When we send a peer update, it should update the coordinatee
nk, err := key.NewNode().Public().MarshalBinary()
require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err)
updates := []*proto.CoordinateResponse_PeerUpdate{
{
Id: agentID[:],
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{
Id: 2,
Key: nk,
Disco: string(dk),
},
},
}
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
require.Eventually(t, func() bool {
fConn.Lock()
defer fConn.Unlock()
return len(fConn.updates) > 0
}, testutil.WaitShort, testutil.IntervalFast)
require.Len(t, fConn.updates[0], 1)
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
errCh := make(chan error, 1)
go func() {
errCh <- uut.Close(ctx)
}()
// When we close, it should gracefully disconnect
req = testutil.RequireRecvCtx(ctx, t, reqs)
require.NotNil(t, req.Disconnect)
close(resps)
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.NoError(t, err)
// It should set all peers lost on the coordinatee
require.Equal(t, 1, fConn.setAllPeersLostCalls)
}
type fakeCoordinatee struct {
sync.Mutex
callback func(*tailnet.Node)
updates [][]*proto.CoordinateResponse_PeerUpdate
setAllPeersLostCalls int
tunnelDestinations map[uuid.UUID]struct{}
}
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
f.Lock()
defer f.Unlock()
f.updates = append(f.updates, updates)
return nil
}
func (f *fakeCoordinatee) SetAllPeersLost() {
f.Lock()
defer f.Unlock()
f.setAllPeersLostCalls++
}
func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
f.Lock()
defer f.Unlock()
if f.tunnelDestinations == nil {
f.tunnelDestinations = map[uuid.UUID]struct{}{}
}
f.tunnelDestinations[id] = struct{}{}
}
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
f.Lock()
defer f.Unlock()
f.callback = callback
}
// TestCoordinatorPropogatedPeerContext tests that the context for a specific peer
// is propogated through to the `Authorize“ method of the coordinatee auth
func TestCoordinatorPropogatedPeerContext(t *testing.T) {


@@ -467,7 +467,8 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
 		_ = conn.Close()
 	})
-	coordination := tailnet.NewRemoteCoordination(logger, coord, conn, peer.ID)
+	ctrl := tailnet.NewSingleDestController(logger, conn, peer.ID)
+	coordination := ctrl.New(coord)
 	t.Cleanup(func() {
 		cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
 		defer cancel()