feat: add support for multiple tunnel destinations in tailnet (#15409)

Closes #14729

Expands the Coordination controller used by the CLI client to allow multiple tunnel destinations (agents).  Our current client uses just one, but this unifies the logic so that when we add Coder VPN, 1 is just a special case of "many."
This commit is contained in:
Spike Curtis
2024-11-08 13:32:07 +04:00
committed by GitHub
parent 8c00ebc6ee
commit e5661c2748
5 changed files with 530 additions and 23 deletions

View File

@ -1918,7 +1918,8 @@ func TestAgent_UpdatedDERP(t *testing.T) {
testCtx, testCtxCancel := context.WithCancel(context.Background()) testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel) t.Cleanup(testCtxCancel)
clientID := uuid.New() clientID := uuid.New()
ctrl := tailnet.NewSingleDestController(logger, conn, agentID) ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
ctrl.AddDestination(agentID)
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID} auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator)) coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, coordinator))
t.Cleanup(func() { t.Cleanup(func() {
@ -2408,7 +2409,8 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
testCtx, testCtxCancel := context.WithCancel(context.Background()) testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel) t.Cleanup(testCtxCancel)
clientID := uuid.New() clientID := uuid.New()
ctrl := tailnet.NewSingleDestController(logger, conn, metadata.AgentID) ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
ctrl.AddDestination(metadata.AgentID)
auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID} auth := tailnet.ClientCoordinateeAuth{AgentID: metadata.AgentID}
coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient( coordination := ctrl.New(tailnet.NewInMemoryCoordinatorClient(
logger, clientID, auth, coordinator)) logger, clientID, auth, coordinator))

View File

@ -268,7 +268,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
_ = conn.Close() _ = conn.Close()
} }
}() }()
controller.CoordCtrl = tailnet.NewSingleDestController(options.Logger, conn, agentID) coordCtrl := tailnet.NewTunnelSrcCoordController(options.Logger, conn)
coordCtrl.AddDestination(agentID)
controller.CoordCtrl = coordCtrl
controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn) controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn)
controller.Run(ctx) controller.Run(ctx)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"maps"
"math" "math"
"strings" "strings"
"sync" "sync"
@ -239,7 +240,8 @@ func (c *BasicCoordination) respLoop() {
defer func() { defer func() {
cErr := c.Client.Close() cErr := c.Client.Close()
if cErr != nil { if cErr != nil {
c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr)) c.logger.Debug(context.Background(),
"failed to close coordinate client after respLoop exit", slog.Error(cErr))
} }
c.coordinatee.SetAllPeersLost() c.coordinatee.SetAllPeersLost()
close(c.respLoopDone) close(c.respLoopDone)
@ -247,7 +249,8 @@ func (c *BasicCoordination) respLoop() {
for { for {
resp, err := c.Client.Recv() resp, err := c.Client.Recv()
if err != nil { if err != nil {
c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err)) c.logger.Debug(context.Background(),
"failed to read from protocol", slog.Error(err))
c.SendErr(xerrors.Errorf("read: %w", err)) c.SendErr(xerrors.Errorf("read: %w", err))
return return
} }
@ -278,7 +281,8 @@ func (c *BasicCoordination) respLoop() {
ReadyForHandshake: rfh, ReadyForHandshake: rfh,
}) })
if err != nil { if err != nil {
c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err)) c.logger.Debug(context.Background(),
"failed to send ready for handshake", slog.Error(err))
c.SendErr(xerrors.Errorf("send: %w", err)) c.SendErr(xerrors.Errorf("send: %w", err))
return return
} }
@ -287,37 +291,158 @@ func (c *BasicCoordination) respLoop() {
} }
} }
type singleDestController struct { type TunnelSrcCoordController struct {
*BasicCoordinationController *BasicCoordinationController
dest uuid.UUID
mu sync.Mutex
dests map[uuid.UUID]struct{}
coordination *BasicCoordination
} }
// NewSingleDestController creates a CoordinationController for Coder clients that connect to a // NewTunnelSrcCoordController creates a CoordinationController for peers that are exclusively
// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent. // tunnel sources (that is, they create tunnel --- Coder clients not workspaces).
func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController { func NewTunnelSrcCoordController(
coordinatee.SetTunnelDestination(dest) logger slog.Logger, coordinatee Coordinatee,
return &singleDestController{ ) *TunnelSrcCoordController {
return &TunnelSrcCoordController{
BasicCoordinationController: &BasicCoordinationController{ BasicCoordinationController: &BasicCoordinationController{
Logger: logger, Logger: logger,
Coordinatee: coordinatee, Coordinatee: coordinatee,
SendAcks: false, SendAcks: false,
}, },
dest: dest, dests: make(map[uuid.UUID]struct{}),
} }
} }
func (c *singleDestController) New(client CoordinatorClient) CloserWaiter { func (c *TunnelSrcCoordController) New(client CoordinatorClient) CloserWaiter {
c.mu.Lock()
defer c.mu.Unlock()
b := c.BasicCoordinationController.NewCoordination(client) b := c.BasicCoordinationController.NewCoordination(client)
err := client.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}}) c.coordination = b
// resync destinations on reconnect
for dest := range c.dests {
err := client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil { if err != nil {
b.SendErr(err) b.SendErr(err)
c.coordination = nil
cErr := client.Close()
if cErr != nil {
c.Logger.Debug(
context.Background(),
"failed to close coordinator client after add tunnel failure",
slog.Error(cErr),
)
}
break
}
} }
return b return b
} }
func (c *TunnelSrcCoordController) AddDestination(dest uuid.UUID) {
c.mu.Lock()
defer c.mu.Unlock()
c.Coordinatee.SetTunnelDestination(dest) // this prepares us for an ack
c.dests[dest] = struct{}{}
if c.coordination == nil {
return
}
err := c.coordination.Client.Send(
&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client after add tunnel failure",
slog.Error(cErr))
}
c.coordination = nil
}
}
func (c *TunnelSrcCoordController) RemoveDestination(dest uuid.UUID) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.dests, dest)
if c.coordination == nil {
return
}
err := c.coordination.Client.Send(
&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // close the client so we don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client after remove tunnel failure",
slog.Error(cErr))
}
c.coordination = nil
}
}
func (c *TunnelSrcCoordController) SyncDestinations(destinations []uuid.UUID) {
c.mu.Lock()
defer c.mu.Unlock()
toAdd := make(map[uuid.UUID]struct{})
toRemove := maps.Clone(c.dests)
all := make(map[uuid.UUID]struct{})
for _, dest := range destinations {
all[dest] = struct{}{}
delete(toRemove, dest)
if _, ok := c.dests[dest]; !ok {
toAdd[dest] = struct{}{}
}
}
c.dests = all
if c.coordination == nil {
return
}
var err error
defer func() {
if err != nil {
c.coordination.SendErr(err)
cErr := c.coordination.Client.Close() // don't gracefully disconnect
if cErr != nil {
c.Logger.Debug(context.Background(),
"failed to close coordinator client during sync destinations",
slog.Error(cErr))
}
c.coordination = nil
}
}()
for dest := range toAdd {
err = c.coordination.Client.Send(
&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
return
}
}
for dest := range toRemove {
err = c.coordination.Client.Send(
&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(dest)},
})
if err != nil {
return
}
}
}
// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never // NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
// create tunnels and always send ReadyToHandshake acknowledgements. // create tunnels and always send ReadyToHandshake acknowledgements.
func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController { func NewAgentCoordinationController(
logger slog.Logger, coordinatee Coordinatee,
) CoordinationController {
return &BasicCoordinationController{ return &BasicCoordinationController{
Logger: logger, Logger: logger,
Coordinatee: coordinatee, Coordinatee: coordinatee,

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -46,7 +47,8 @@ func TestInMemoryCoordination(t *testing.T) {
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth). mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth).
Times(1).Return(reqs, resps) Times(1).Return(reqs, resps)
ctrl := tailnet.NewSingleDestController(logger, fConn, agentID) ctrl := tailnet.NewTunnelSrcCoordController(logger, fConn)
ctrl.AddDestination(agentID)
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord)) uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord))
defer uut.Close(ctx) defer uut.Close(ctx)
@ -57,7 +59,7 @@ func TestInMemoryCoordination(t *testing.T) {
require.ErrorIs(t, err, io.EOF) require.ErrorIs(t, err, io.EOF)
} }
func TestSingleDestController(t *testing.T) { func TestTunnelSrcCoordController_Mainline(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -102,7 +104,8 @@ func TestSingleDestController(t *testing.T) {
protocol, err := client.Coordinate(ctx) protocol, err := client.Coordinate(ctx)
require.NoError(t, err) require.NoError(t, err)
ctrl := tailnet.NewSingleDestController(logger.Named("coordination"), fConn, agentID) ctrl := tailnet.NewTunnelSrcCoordController(logger.Named("coordination"), fConn)
ctrl.AddDestination(agentID)
uut := ctrl.New(protocol) uut := ctrl.New(protocol)
defer uut.Close(ctx) defer uut.Close(ctx)
@ -113,6 +116,284 @@ func TestSingleDestController(t *testing.T) {
require.ErrorIs(t, err, io.EOF) require.ErrorIs(t, err, io.EOF)
} }
func TestTunnelSrcCoordController_AddDestination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fConn := &fakeCoordinatee{}
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
// GIVEN: client already connected
client1 := newFakeCoordinatorClient(ctx, t)
cw1 := uut.New(client1)
// WHEN: we add 2 destinations
dest1 := uuid.UUID{1}
dest2 := uuid.UUID{2}
addDone := make(chan struct{})
go func() {
defer close(addDone)
uut.AddDestination(dest1)
uut.AddDestination(dest2)
}()
// THEN: Controller sends AddTunnel for the destinations
for i := range 2 {
b0 := byte(i + 1)
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
require.Equal(t, b0, call.req.GetAddTunnel().GetId()[0])
testutil.RequireSendCtx(ctx, t, call.err, nil)
}
_ = testutil.RequireRecvCtx(ctx, t, addDone)
// THEN: Controller sets destinations on Coordinatee
require.Contains(t, fConn.tunnelDestinations, dest1)
require.Contains(t, fConn.tunnelDestinations, dest2)
// WHEN: Closed from server side and reconnects
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
client2 := newFakeCoordinatorClient(ctx, t)
cws := make(chan tailnet.CloserWaiter)
go func() {
cws <- uut.New(client2)
}()
// THEN: should immediately send both destinations
var dests []byte
for range 2 {
call := testutil.RequireRecvCtx(ctx, t, client2.reqs)
dests = append(dests, call.req.GetAddTunnel().GetId()[0])
testutil.RequireSendCtx(ctx, t, call.err, nil)
}
slices.Sort(dests)
require.Equal(t, dests, []byte{1, 2})
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
// close client2
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
require.ErrorIs(t, err, io.EOF)
}
func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fConn := &fakeCoordinatee{}
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
// GIVEN: 1 destination
dest1 := uuid.UUID{1}
uut.AddDestination(dest1)
// GIVEN: client already connected
client1 := newFakeCoordinatorClient(ctx, t)
cws := make(chan tailnet.CloserWaiter)
go func() {
cws <- uut.New(client1)
}()
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
// WHEN: we remove one destination
removeDone := make(chan struct{})
go func() {
defer close(removeDone)
uut.RemoveDestination(dest1)
}()
// THEN: Controller sends RemoveTunnel for the destination
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
// WHEN: Closed from server side and reconnect
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
client2 := newFakeCoordinatorClient(ctx, t)
go func() {
cws <- uut.New(client2)
}()
// THEN: should immediately resolve without sending anything
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
// close client2
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
require.ErrorIs(t, err, io.EOF)
}
func TestTunnelSrcCoordController_RemoveDestination_Error(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fConn := &fakeCoordinatee{}
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
// GIVEN: 3 destination
dest1 := uuid.UUID{1}
dest2 := uuid.UUID{2}
dest3 := uuid.UUID{3}
uut.AddDestination(dest1)
uut.AddDestination(dest2)
uut.AddDestination(dest3)
// GIVEN: client already connected
client1 := newFakeCoordinatorClient(ctx, t)
cws := make(chan tailnet.CloserWaiter)
go func() {
cws <- uut.New(client1)
}()
for range 3 {
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
}
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
// WHEN: we remove all destinations
removeDone := make(chan struct{})
go func() {
defer close(removeDone)
uut.RemoveDestination(dest1)
uut.RemoveDestination(dest2)
uut.RemoveDestination(dest3)
}()
// WHEN: first RemoveTunnel call fails
theErr := xerrors.New("a bad thing happened")
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, theErr)
// THEN: we disconnect and do not send remaining RemoveTunnel messages
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
// shut down
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
// triggers second close call
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
require.ErrorIs(t, err, theErr)
}
func TestTunnelSrcCoordController_Sync(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fConn := &fakeCoordinatee{}
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
dest1 := uuid.UUID{1}
dest2 := uuid.UUID{2}
dest3 := uuid.UUID{3}
// GIVEN: dest1 & dest2 already added
uut.AddDestination(dest1)
uut.AddDestination(dest2)
// GIVEN: client already connected
client1 := newFakeCoordinatorClient(ctx, t)
cws := make(chan tailnet.CloserWaiter)
go func() {
cws <- uut.New(client1)
}()
for range 2 {
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, nil)
}
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
// WHEN: we sync dest2 & dest3
syncDone := make(chan struct{})
go func() {
defer close(syncDone)
uut.SyncDestinations([]uuid.UUID{dest2, dest3})
}()
// THEN: we get an add for dest3 and remove for dest1
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
require.Equal(t, dest3[:], call.req.GetAddTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
testutil.RequireSendCtx(ctx, t, call.err, nil)
// shut down
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
require.ErrorIs(t, err, io.EOF)
}
func TestTunnelSrcCoordController_AddDestination_Error(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fConn := &fakeCoordinatee{}
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
// GIVEN: client already connected
client1 := newFakeCoordinatorClient(ctx, t)
cw1 := uut.New(client1)
// WHEN: we add a destination, and the AddTunnel fails
dest1 := uuid.UUID{1}
addDone := make(chan struct{})
go func() {
defer close(addDone)
uut.AddDestination(dest1)
}()
theErr := xerrors.New("a bad thing happened")
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
testutil.RequireSendCtx(ctx, t, call.err, theErr)
// THEN: Client is closed and exits
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
// close the resps, since the client has closed
resp := testutil.RequireRecvCtx(ctx, t, client1.resps)
testutil.RequireSendCtx(ctx, t, resp.err, net.ErrClosed)
// this triggers a second Close() call on the client
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
testutil.RequireSendCtx(ctx, t, closeCall, nil)
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
require.ErrorIs(t, err, theErr)
_ = testutil.RequireRecvCtx(ctx, t, addDone)
}
func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) { func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
@ -885,3 +1166,99 @@ func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (t
Telemetry: client, Telemetry: client,
}, nil }, nil
} }
type fakeCoordinatorClient struct {
ctx context.Context
t testing.TB
reqs chan *coordReqCall
resps chan *coordRespCall
close chan chan<- error
}
func (f fakeCoordinatorClient) Close() error {
f.t.Helper()
errs := make(chan error)
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send close call")
return f.ctx.Err()
case f.close <- errs:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for close call response")
return f.ctx.Err()
case err := <-errs:
return err
}
}
func (f fakeCoordinatorClient) Send(request *proto.CoordinateRequest) error {
f.t.Helper()
errs := make(chan error)
call := &coordReqCall{
req: request,
err: errs,
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send call")
return f.ctx.Err()
case f.reqs <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for send call response")
return f.ctx.Err()
case err := <-errs:
return err
}
}
func (f fakeCoordinatorClient) Recv() (*proto.CoordinateResponse, error) {
f.t.Helper()
resps := make(chan *proto.CoordinateResponse)
errs := make(chan error)
call := &coordRespCall{
resp: resps,
err: errs,
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting to send Recv() call")
return nil, f.ctx.Err()
case f.resps <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timed out waiting for Recv() call response")
return nil, f.ctx.Err()
case err := <-errs:
return nil, err
case resp := <-resps:
return resp, nil
}
}
func newFakeCoordinatorClient(ctx context.Context, t testing.TB) *fakeCoordinatorClient {
return &fakeCoordinatorClient{
ctx: ctx,
t: t,
reqs: make(chan *coordReqCall),
resps: make(chan *coordRespCall),
close: make(chan chan<- error),
}
}
type coordReqCall struct {
req *proto.CoordinateRequest
err chan<- error
}
type coordRespCall struct {
resp chan<- *proto.CoordinateResponse
err chan<- error
}

View File

@ -467,7 +467,8 @@ func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me
_ = conn.Close() _ = conn.Close()
}) })
ctrl := tailnet.NewSingleDestController(logger, conn, peer.ID) ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
ctrl.AddDestination(peer.ID)
coordination := ctrl.New(coord) coordination := ctrl.New(coord)
t.Cleanup(func() { t.Cleanup(func() {
cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)