mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
Addresses #14734. This PR wires up `tunnel.go` to a `tailnet.Conn` via the new `/tailnet` endpoint, with all the necessary controllers such that a VPN connection can be started, stopped and inspected via the CoderVPN protocol.
2020 lines
57 KiB
Go
2020 lines
57 KiB
Go
package tailnet_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/hashicorp/yamux"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcerr"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/key"
|
|
"tailscale.com/util/dnsname"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// unimplementedError is what the fake RPC clients below hand back to simulate
// a server that does not implement the requested RPC: an "Unimplemented"
// message tagged with the DRPC Unimplemented error code.
var unimplementedError = drpcerr.WithCode(
	xerrors.New("Unimplemented"),
	drpcerr.Unimplemented,
)
|
|
|
|
func TestInMemoryCoordination(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
clientID := uuid.UUID{1}
|
|
agentID := uuid.UUID{2}
|
|
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
|
fConn := &fakeCoordinatee{}
|
|
|
|
reqs := make(chan *proto.CoordinateRequest, 100)
|
|
resps := make(chan *proto.CoordinateResponse, 100)
|
|
auth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
|
|
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), auth).
|
|
Times(1).Return(reqs, resps)
|
|
|
|
ctrl := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
ctrl.AddDestination(agentID)
|
|
uut := ctrl.New(tailnet.NewInMemoryCoordinatorClient(logger, clientID, auth, mCoord))
|
|
defer uut.Close(ctx)
|
|
|
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
|
|
|
// Recv loop should be terminated by the server hanging up after Disconnect
|
|
err := testutil.RequireRecvCtx(ctx, t, uut.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_Mainline(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
clientID := uuid.UUID{1}
|
|
agentID := uuid.UUID{2}
|
|
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
|
fConn := &fakeCoordinatee{}
|
|
|
|
reqs := make(chan *proto.CoordinateRequest, 100)
|
|
resps := make(chan *proto.CoordinateResponse, 100)
|
|
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
|
|
Times(1).Return(reqs, resps)
|
|
|
|
var coord tailnet.Coordinator = mCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger.Named("svc"),
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
sC, cC := net.Pipe()
|
|
|
|
serveErr := make(chan error, 1)
|
|
go func() {
|
|
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{
|
|
AgentID: agentID,
|
|
},
|
|
})
|
|
serveErr <- err
|
|
}()
|
|
|
|
client, err := tailnet.NewDRPCClient(cC, logger)
|
|
require.NoError(t, err)
|
|
protocol, err := client.Coordinate(ctx)
|
|
require.NoError(t, err)
|
|
|
|
ctrl := tailnet.NewTunnelSrcCoordController(logger.Named("coordination"), fConn)
|
|
ctrl.AddDestination(agentID)
|
|
uut := ctrl.New(protocol)
|
|
defer uut.Close(ctx)
|
|
|
|
coordinationTest(ctx, t, uut, fConn, reqs, resps, agentID)
|
|
|
|
// Recv loop should be terminated by the server hanging up after Disconnect
|
|
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_AddDestination(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
|
|
// GIVEN: client already connected
|
|
client1 := newFakeCoordinatorClient(ctx, t)
|
|
cw1 := uut.New(client1)
|
|
|
|
// WHEN: we add 2 destinations
|
|
dest1 := uuid.UUID{1}
|
|
dest2 := uuid.UUID{2}
|
|
addDone := make(chan struct{})
|
|
go func() {
|
|
defer close(addDone)
|
|
uut.AddDestination(dest1)
|
|
uut.AddDestination(dest2)
|
|
}()
|
|
|
|
// THEN: Controller sends AddTunnel for the destinations
|
|
for i := range 2 {
|
|
b0 := byte(i + 1)
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
require.Equal(t, b0, call.req.GetAddTunnel().GetId()[0])
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
}
|
|
_ = testutil.RequireRecvCtx(ctx, t, addDone)
|
|
|
|
// THEN: Controller sets destinations on Coordinatee
|
|
require.Contains(t, fConn.tunnelDestinations, dest1)
|
|
require.Contains(t, fConn.tunnelDestinations, dest2)
|
|
|
|
// WHEN: Closed from server side and reconnects
|
|
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
client2 := newFakeCoordinatorClient(ctx, t)
|
|
cws := make(chan tailnet.CloserWaiter)
|
|
go func() {
|
|
cws <- uut.New(client2)
|
|
}()
|
|
|
|
// THEN: should immediately send both destinations
|
|
var dests []byte
|
|
for range 2 {
|
|
call := testutil.RequireRecvCtx(ctx, t, client2.reqs)
|
|
dests = append(dests, call.req.GetAddTunnel().GetId()[0])
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
}
|
|
slices.Sort(dests)
|
|
require.Equal(t, dests, []byte{1, 2})
|
|
|
|
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
|
|
|
|
// close client2
|
|
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_RemoveDestination(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
|
|
// GIVEN: 1 destination
|
|
dest1 := uuid.UUID{1}
|
|
uut.AddDestination(dest1)
|
|
|
|
// GIVEN: client already connected
|
|
client1 := newFakeCoordinatorClient(ctx, t)
|
|
cws := make(chan tailnet.CloserWaiter)
|
|
go func() {
|
|
cws <- uut.New(client1)
|
|
}()
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
|
|
|
|
// WHEN: we remove one destination
|
|
removeDone := make(chan struct{})
|
|
go func() {
|
|
defer close(removeDone)
|
|
uut.RemoveDestination(dest1)
|
|
}()
|
|
|
|
// THEN: Controller sends RemoveTunnel for the destination
|
|
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
|
|
|
|
// WHEN: Closed from server side and reconnect
|
|
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
|
|
client2 := newFakeCoordinatorClient(ctx, t)
|
|
go func() {
|
|
cws <- uut.New(client2)
|
|
}()
|
|
|
|
// THEN: should immediately resolve without sending anything
|
|
cw2 := testutil.RequireRecvCtx(ctx, t, cws)
|
|
|
|
// close client2
|
|
respCall = testutil.RequireRecvCtx(ctx, t, client2.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
closeCall = testutil.RequireRecvCtx(ctx, t, client2.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_RemoveDestination_Error(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
|
|
// GIVEN: 3 destination
|
|
dest1 := uuid.UUID{1}
|
|
dest2 := uuid.UUID{2}
|
|
dest3 := uuid.UUID{3}
|
|
uut.AddDestination(dest1)
|
|
uut.AddDestination(dest2)
|
|
uut.AddDestination(dest3)
|
|
|
|
// GIVEN: client already connected
|
|
client1 := newFakeCoordinatorClient(ctx, t)
|
|
cws := make(chan tailnet.CloserWaiter)
|
|
go func() {
|
|
cws <- uut.New(client1)
|
|
}()
|
|
for range 3 {
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
}
|
|
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
|
|
|
|
// WHEN: we remove all destinations
|
|
removeDone := make(chan struct{})
|
|
go func() {
|
|
defer close(removeDone)
|
|
uut.RemoveDestination(dest1)
|
|
uut.RemoveDestination(dest2)
|
|
uut.RemoveDestination(dest3)
|
|
}()
|
|
|
|
// WHEN: first RemoveTunnel call fails
|
|
theErr := xerrors.New("a bad thing happened")
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, call.err, theErr)
|
|
|
|
// THEN: we disconnect and do not send remaining RemoveTunnel messages
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
_ = testutil.RequireRecvCtx(ctx, t, removeDone)
|
|
|
|
// shut down
|
|
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
// triggers second close call
|
|
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.ErrorIs(t, err, theErr)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_Sync(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
dest1 := uuid.UUID{1}
|
|
dest2 := uuid.UUID{2}
|
|
dest3 := uuid.UUID{3}
|
|
|
|
// GIVEN: dest1 & dest2 already added
|
|
uut.AddDestination(dest1)
|
|
uut.AddDestination(dest2)
|
|
|
|
// GIVEN: client already connected
|
|
client1 := newFakeCoordinatorClient(ctx, t)
|
|
cws := make(chan tailnet.CloserWaiter)
|
|
go func() {
|
|
cws <- uut.New(client1)
|
|
}()
|
|
for range 2 {
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
}
|
|
cw1 := testutil.RequireRecvCtx(ctx, t, cws)
|
|
|
|
// WHEN: we sync dest2 & dest3
|
|
syncDone := make(chan struct{})
|
|
go func() {
|
|
defer close(syncDone)
|
|
uut.SyncDestinations([]uuid.UUID{dest2, dest3})
|
|
}()
|
|
|
|
// THEN: we get an add for dest3 and remove for dest1
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
require.Equal(t, dest3[:], call.req.GetAddTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
call = testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
require.Equal(t, dest1[:], call.req.GetRemoveTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, call.err, nil)
|
|
|
|
testutil.RequireRecvCtx(ctx, t, syncDone)
|
|
// dest3 should be added to coordinatee
|
|
require.Contains(t, fConn.tunnelDestinations, dest3)
|
|
|
|
// shut down
|
|
respCall := testutil.RequireRecvCtx(ctx, t, client1.resps)
|
|
testutil.RequireSendCtx(ctx, t, respCall.err, io.EOF)
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
func TestTunnelSrcCoordController_AddDestination_Error(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
uut := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
|
|
// GIVEN: client already connected
|
|
client1 := newFakeCoordinatorClient(ctx, t)
|
|
cw1 := uut.New(client1)
|
|
|
|
// WHEN: we add a destination, and the AddTunnel fails
|
|
dest1 := uuid.UUID{1}
|
|
addDone := make(chan struct{})
|
|
go func() {
|
|
defer close(addDone)
|
|
uut.AddDestination(dest1)
|
|
}()
|
|
theErr := xerrors.New("a bad thing happened")
|
|
call := testutil.RequireRecvCtx(ctx, t, client1.reqs)
|
|
testutil.RequireSendCtx(ctx, t, call.err, theErr)
|
|
|
|
// THEN: Client is closed and exits
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
|
|
// close the resps, since the client has closed
|
|
resp := testutil.RequireRecvCtx(ctx, t, client1.resps)
|
|
testutil.RequireSendCtx(ctx, t, resp.err, net.ErrClosed)
|
|
// this triggers a second Close() call on the client
|
|
closeCall = testutil.RequireRecvCtx(ctx, t, client1.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.ErrorIs(t, err, theErr)
|
|
|
|
_ = testutil.RequireRecvCtx(ctx, t, addDone)
|
|
}
|
|
|
|
func TestAgentCoordinationController_SendsReadyForHandshake(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
clientID := uuid.UUID{1}
|
|
agentID := uuid.UUID{2}
|
|
mCoord := tailnettest.NewMockCoordinator(gomock.NewController(t))
|
|
fConn := &fakeCoordinatee{}
|
|
|
|
reqs := make(chan *proto.CoordinateRequest, 100)
|
|
resps := make(chan *proto.CoordinateResponse, 100)
|
|
mCoord.EXPECT().Coordinate(gomock.Any(), clientID, gomock.Any(), tailnet.ClientCoordinateeAuth{agentID}).
|
|
Times(1).Return(reqs, resps)
|
|
|
|
var coord tailnet.Coordinator = mCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger.Named("svc"),
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Hour,
|
|
DERPMapFn: func() *tailcfg.DERPMap { panic("not implemented") },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { panic("not implemented") },
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
sC, cC := net.Pipe()
|
|
|
|
serveErr := make(chan error, 1)
|
|
go func() {
|
|
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{
|
|
AgentID: agentID,
|
|
},
|
|
})
|
|
serveErr <- err
|
|
}()
|
|
|
|
client, err := tailnet.NewDRPCClient(cC, logger)
|
|
require.NoError(t, err)
|
|
protocol, err := client.Coordinate(ctx)
|
|
require.NoError(t, err)
|
|
|
|
ctrl := tailnet.NewAgentCoordinationController(logger.Named("coordination"), fConn)
|
|
uut := ctrl.New(protocol)
|
|
defer uut.Close(ctx)
|
|
|
|
nk, err := key.NewNode().Public().MarshalBinary()
|
|
require.NoError(t, err)
|
|
dk, err := key.NewDisco().Public().MarshalText()
|
|
require.NoError(t, err)
|
|
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{
|
|
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{{
|
|
Id: clientID[:],
|
|
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
|
Node: &proto.Node{
|
|
Id: 3,
|
|
Key: nk,
|
|
Disco: string(dk),
|
|
},
|
|
}},
|
|
})
|
|
|
|
rfh := testutil.RequireRecvCtx(ctx, t, reqs)
|
|
require.NotNil(t, rfh.ReadyForHandshake)
|
|
require.Len(t, rfh.ReadyForHandshake, 1)
|
|
require.Equal(t, clientID[:], rfh.ReadyForHandshake[0].Id)
|
|
|
|
go uut.Close(ctx)
|
|
dis := testutil.RequireRecvCtx(ctx, t, reqs)
|
|
require.NotNil(t, dis)
|
|
require.NotNil(t, dis.Disconnect)
|
|
close(resps)
|
|
|
|
// Recv loop should be terminated by the server hanging up after Disconnect
|
|
err = testutil.RequireRecvCtx(ctx, t, uut.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
}
|
|
|
|
// coordinationTest tests that a coordination behaves correctly
|
|
func coordinationTest(
|
|
ctx context.Context, t *testing.T,
|
|
uut tailnet.CloserWaiter, fConn *fakeCoordinatee,
|
|
reqs chan *proto.CoordinateRequest, resps chan *proto.CoordinateResponse,
|
|
agentID uuid.UUID,
|
|
) {
|
|
// It should add the tunnel, since we configured as a client
|
|
req := testutil.RequireRecvCtx(ctx, t, reqs)
|
|
require.Equal(t, agentID[:], req.GetAddTunnel().GetId())
|
|
|
|
// when we call the callback, it should send a node update
|
|
require.NotNil(t, fConn.callback)
|
|
fConn.callback(&tailnet.Node{PreferredDERP: 1})
|
|
|
|
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
|
require.Equal(t, int32(1), req.GetUpdateSelf().GetNode().GetPreferredDerp())
|
|
|
|
// When we send a peer update, it should update the coordinatee
|
|
nk, err := key.NewNode().Public().MarshalBinary()
|
|
require.NoError(t, err)
|
|
dk, err := key.NewDisco().Public().MarshalText()
|
|
require.NoError(t, err)
|
|
updates := []*proto.CoordinateResponse_PeerUpdate{
|
|
{
|
|
Id: agentID[:],
|
|
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
|
Node: &proto.Node{
|
|
Id: 2,
|
|
Key: nk,
|
|
Disco: string(dk),
|
|
},
|
|
},
|
|
}
|
|
testutil.RequireSendCtx(ctx, t, resps, &proto.CoordinateResponse{PeerUpdates: updates})
|
|
require.Eventually(t, func() bool {
|
|
fConn.Lock()
|
|
defer fConn.Unlock()
|
|
return len(fConn.updates) > 0
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
require.Len(t, fConn.updates[0], 1)
|
|
require.Equal(t, agentID[:], fConn.updates[0][0].Id)
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- uut.Close(ctx)
|
|
}()
|
|
|
|
// When we close, it should gracefully disconnect
|
|
req = testutil.RequireRecvCtx(ctx, t, reqs)
|
|
require.NotNil(t, req.Disconnect)
|
|
close(resps)
|
|
|
|
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
|
require.NoError(t, err)
|
|
|
|
// It should set all peers lost on the coordinatee
|
|
require.Equal(t, 1, fConn.setAllPeersLostCalls)
|
|
}
|
|
|
|
type fakeCoordinatee struct {
|
|
sync.Mutex
|
|
callback func(*tailnet.Node)
|
|
updates [][]*proto.CoordinateResponse_PeerUpdate
|
|
setAllPeersLostCalls int
|
|
tunnelDestinations map[uuid.UUID]struct{}
|
|
}
|
|
|
|
func (f *fakeCoordinatee) UpdatePeers(updates []*proto.CoordinateResponse_PeerUpdate) error {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.updates = append(f.updates, updates)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeCoordinatee) SetAllPeersLost() {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.setAllPeersLostCalls++
|
|
}
|
|
|
|
func (f *fakeCoordinatee) SetTunnelDestination(id uuid.UUID) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
|
|
if f.tunnelDestinations == nil {
|
|
f.tunnelDestinations = map[uuid.UUID]struct{}{}
|
|
}
|
|
f.tunnelDestinations[id] = struct{}{}
|
|
}
|
|
|
|
func (f *fakeCoordinatee) SetNodeCallback(callback func(*tailnet.Node)) {
|
|
f.Lock()
|
|
defer f.Unlock()
|
|
f.callback = callback
|
|
}
|
|
|
|
func TestNewBasicDERPController_Mainline(t *testing.T) {
|
|
t.Parallel()
|
|
fs := make(chan *tailcfg.DERPMap)
|
|
logger := testutil.Logger(t)
|
|
uut := tailnet.NewBasicDERPController(logger, fakeSetter(fs))
|
|
fc := fakeDERPClient{
|
|
ch: make(chan *tailcfg.DERPMap),
|
|
}
|
|
c := uut.New(fc)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
expectDM := &tailcfg.DERPMap{}
|
|
testutil.RequireSendCtx(ctx, t, fc.ch, expectDM)
|
|
gotDM := testutil.RequireRecvCtx(ctx, t, fs)
|
|
require.Equal(t, expectDM, gotDM)
|
|
err := c.Close(ctx)
|
|
require.NoError(t, err)
|
|
err = testutil.RequireRecvCtx(ctx, t, c.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
// ensure Close is idempotent
|
|
err = c.Close(ctx)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestNewBasicDERPController_RecvErr(t *testing.T) {
|
|
t.Parallel()
|
|
fs := make(chan *tailcfg.DERPMap)
|
|
logger := testutil.Logger(t)
|
|
uut := tailnet.NewBasicDERPController(logger, fakeSetter(fs))
|
|
expectedErr := xerrors.New("a bad thing happened")
|
|
fc := fakeDERPClient{
|
|
ch: make(chan *tailcfg.DERPMap),
|
|
err: expectedErr,
|
|
}
|
|
c := uut.New(fc)
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
err := testutil.RequireRecvCtx(ctx, t, c.Wait())
|
|
require.ErrorIs(t, err, expectedErr)
|
|
// ensure Close is idempotent
|
|
err = c.Close(ctx)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type fakeSetter chan *tailcfg.DERPMap
|
|
|
|
func (s fakeSetter) SetDERPMap(derpMap *tailcfg.DERPMap) {
|
|
s <- derpMap
|
|
}
|
|
|
|
type fakeDERPClient struct {
|
|
ch chan *tailcfg.DERPMap
|
|
err error
|
|
}
|
|
|
|
func (f fakeDERPClient) Close() error {
|
|
close(f.ch)
|
|
return nil
|
|
}
|
|
|
|
func (f fakeDERPClient) Recv() (*tailcfg.DERPMap, error) {
|
|
if f.err != nil {
|
|
return nil, f.err
|
|
}
|
|
dm, ok := <-f.ch
|
|
if ok {
|
|
return dm, nil
|
|
}
|
|
return nil, io.EOF
|
|
}
|
|
|
|
func TestBasicTelemetryController_Success(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
uut := tailnet.NewBasicTelemetryController(logger)
|
|
ft := newFakeTelemetryClient()
|
|
uut.New(ft)
|
|
|
|
sendDone := make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{
|
|
Id: []byte("test event"),
|
|
})
|
|
}()
|
|
|
|
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
require.Len(t, call.req.GetEvents(), 1)
|
|
require.Equal(t, call.req.GetEvents()[0].GetId(), []byte("test event"))
|
|
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, nil)
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
}
|
|
|
|
func TestBasicTelemetryController_Unimplemented(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
ft := newFakeTelemetryClient()
|
|
|
|
uut := tailnet.NewBasicTelemetryController(logger)
|
|
uut.New(ft)
|
|
|
|
// bad code, doesn't count
|
|
telemetryError := drpcerr.WithCode(xerrors.New("Unimplemented"), 0)
|
|
|
|
sendDone := make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
|
|
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError)
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
|
|
sendDone = make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
|
|
// we get another call since it wasn't really the Unimplemented error
|
|
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
|
|
// for real this time
|
|
telemetryError = unimplementedError
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError)
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
|
|
// now this returns immediately without a call, because unimplemented error disables calling
|
|
sendDone = make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
|
|
// getting a "new" client resets
|
|
uut.New(ft)
|
|
sendDone = make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, nil)
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
}
|
|
|
|
func TestBasicTelemetryController_NotRecognised(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
ft := newFakeTelemetryClient()
|
|
uut := tailnet.NewBasicTelemetryController(logger)
|
|
uut.New(ft)
|
|
|
|
sendDone := make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
// returning generic protocol error doesn't trigger unknown rpc logic
|
|
call := testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, drpc.ProtocolError.New("Protocol Error"))
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
|
|
sendDone = make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
call = testutil.RequireRecvCtx(ctx, t, ft.calls)
|
|
// return the expected protocol error this time
|
|
testutil.RequireSendCtx(ctx, t, call.errCh,
|
|
drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry"))
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
|
|
// now this returns immediately without a call, because unimplemented error disables calling
|
|
sendDone = make(chan struct{})
|
|
go func() {
|
|
defer close(sendDone)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
}()
|
|
testutil.RequireRecvCtx(ctx, t, sendDone)
|
|
}
|
|
|
|
type fakeTelemetryClient struct {
|
|
calls chan *fakeTelemetryCall
|
|
}
|
|
|
|
var _ tailnet.TelemetryClient = &fakeTelemetryClient{}
|
|
|
|
func newFakeTelemetryClient() *fakeTelemetryClient {
|
|
return &fakeTelemetryClient{
|
|
calls: make(chan *fakeTelemetryCall),
|
|
}
|
|
}
|
|
|
|
// PostTelemetry implements tailnet.TelemetryClient
|
|
func (f *fakeTelemetryClient) PostTelemetry(ctx context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
|
|
fr := &fakeTelemetryCall{req: req, errCh: make(chan error)}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case f.calls <- fr:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case err := <-fr.errCh:
|
|
return &proto.TelemetryResponse{}, err
|
|
}
|
|
}
|
|
|
|
type fakeTelemetryCall struct {
|
|
req *proto.TelemetryRequest
|
|
errCh chan error
|
|
}
|
|
|
|
func TestBasicResumeTokenController_Mainline(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
fr := newFakeResumeTokenClient(ctx)
|
|
mClock := quartz.NewMock(t)
|
|
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
|
|
defer trp.Close()
|
|
|
|
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
|
|
_, ok := uut.Token()
|
|
require.False(t, ok)
|
|
|
|
cwCh := make(chan tailnet.CloserWaiter, 1)
|
|
go func() {
|
|
cwCh <- uut.New(fr)
|
|
}()
|
|
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
|
|
Token: "test token 1",
|
|
RefreshIn: durationpb.New(100 * time.Second),
|
|
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
|
|
})
|
|
trp.MustWait(ctx).Release() // initial refresh done
|
|
token, ok := uut.Token()
|
|
require.True(t, ok)
|
|
require.Equal(t, "test token 1", token)
|
|
cw := testutil.RequireRecvCtx(ctx, t, cwCh)
|
|
|
|
w := mClock.Advance(100 * time.Second)
|
|
call = testutil.RequireRecvCtx(ctx, t, fr.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
|
|
Token: "test token 2",
|
|
RefreshIn: durationpb.New(50 * time.Second),
|
|
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
|
|
})
|
|
resetCall := trp.MustWait(ctx)
|
|
require.Equal(t, resetCall.Duration, 50*time.Second)
|
|
resetCall.Release()
|
|
w.MustWait(ctx)
|
|
token, ok = uut.Token()
|
|
require.True(t, ok)
|
|
require.Equal(t, "test token 2", token)
|
|
|
|
err := cw.Close(ctx)
|
|
require.NoError(t, err)
|
|
err = testutil.RequireRecvCtx(ctx, t, cw.Wait())
|
|
require.NoError(t, err)
|
|
|
|
token, ok = uut.Token()
|
|
require.True(t, ok)
|
|
require.Equal(t, "test token 2", token)
|
|
|
|
mClock.Advance(201 * time.Second).MustWait(ctx)
|
|
_, ok = uut.Token()
|
|
require.False(t, ok)
|
|
}
|
|
|
|
func TestBasicResumeTokenController_NewWhileRefreshing(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
mClock := quartz.NewMock(t)
|
|
trp := mClock.Trap().TimerReset("basicResumeTokenRefresher", "refresh")
|
|
defer trp.Close()
|
|
|
|
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
|
|
_, ok := uut.Token()
|
|
require.False(t, ok)
|
|
|
|
fr1 := newFakeResumeTokenClient(ctx)
|
|
cwCh1 := make(chan tailnet.CloserWaiter, 1)
|
|
go func() {
|
|
cwCh1 <- uut.New(fr1)
|
|
}()
|
|
call1 := testutil.RequireRecvCtx(ctx, t, fr1.calls)
|
|
|
|
fr2 := newFakeResumeTokenClient(ctx)
|
|
cwCh2 := make(chan tailnet.CloserWaiter, 1)
|
|
go func() {
|
|
cwCh2 <- uut.New(fr2)
|
|
}()
|
|
call2 := testutil.RequireRecvCtx(ctx, t, fr2.calls)
|
|
|
|
testutil.RequireSendCtx(ctx, t, call2.resp, &proto.RefreshResumeTokenResponse{
|
|
Token: "test token 2.0",
|
|
RefreshIn: durationpb.New(102 * time.Second),
|
|
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
|
|
})
|
|
|
|
cw2 := testutil.RequireRecvCtx(ctx, t, cwCh2) // this ensures Close was called on 1
|
|
|
|
testutil.RequireSendCtx(ctx, t, call1.resp, &proto.RefreshResumeTokenResponse{
|
|
Token: "test token 1",
|
|
RefreshIn: durationpb.New(101 * time.Second),
|
|
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
|
|
})
|
|
|
|
trp.MustWait(ctx).Release()
|
|
|
|
token, ok := uut.Token()
|
|
require.True(t, ok)
|
|
require.Equal(t, "test token 2.0", token)
|
|
|
|
// refresher 1 should already be closed.
|
|
cw1 := testutil.RequireRecvCtx(ctx, t, cwCh1)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw1.Wait())
|
|
require.NoError(t, err)
|
|
|
|
w := mClock.Advance(102 * time.Second)
|
|
call := testutil.RequireRecvCtx(ctx, t, fr2.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.resp, &proto.RefreshResumeTokenResponse{
|
|
Token: "test token 2.1",
|
|
RefreshIn: durationpb.New(50 * time.Second),
|
|
ExpiresAt: timestamppb.New(mClock.Now().Add(200 * time.Second)),
|
|
})
|
|
resetCall := trp.MustWait(ctx)
|
|
require.Equal(t, resetCall.Duration, 50*time.Second)
|
|
resetCall.Release()
|
|
w.MustWait(ctx)
|
|
token, ok = uut.Token()
|
|
require.True(t, ok)
|
|
require.Equal(t, "test token 2.1", token)
|
|
|
|
err = cw2.Close(ctx)
|
|
require.NoError(t, err)
|
|
err = testutil.RequireRecvCtx(ctx, t, cw2.Wait())
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestBasicResumeTokenController_Unimplemented(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
mClock := quartz.NewMock(t)
|
|
|
|
uut := tailnet.NewBasicResumeTokenController(logger, mClock)
|
|
_, ok := uut.Token()
|
|
require.False(t, ok)
|
|
|
|
fr := newFakeResumeTokenClient(ctx)
|
|
cw := uut.New(fr)
|
|
|
|
call := testutil.RequireRecvCtx(ctx, t, fr.calls)
|
|
testutil.RequireSendCtx(ctx, t, call.errCh, unimplementedError)
|
|
err := testutil.RequireRecvCtx(ctx, t, cw.Wait())
|
|
require.NoError(t, err)
|
|
_, ok = uut.Token()
|
|
require.False(t, ok)
|
|
}
|
|
|
|
func newFakeResumeTokenClient(ctx context.Context) *fakeResumeTokenClient {
|
|
return &fakeResumeTokenClient{
|
|
ctx: ctx,
|
|
calls: make(chan *fakeResumeTokenCall),
|
|
}
|
|
}
|
|
|
|
type fakeResumeTokenClient struct {
|
|
ctx context.Context
|
|
calls chan *fakeResumeTokenCall
|
|
}
|
|
|
|
func (f *fakeResumeTokenClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
|
|
call := &fakeResumeTokenCall{
|
|
resp: make(chan *proto.RefreshResumeTokenResponse),
|
|
errCh: make(chan error),
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case f.calls <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case err := <-call.errCh:
|
|
return nil, err
|
|
case resp := <-call.resp:
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
type fakeResumeTokenCall struct {
|
|
resp chan *proto.RefreshResumeTokenResponse
|
|
errCh chan error
|
|
}
|
|
|
|
func TestController_Disconnects(t *testing.T) {
|
|
t.Parallel()
|
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
|
ctx, cancel := context.WithCancel(testCtx)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs,
|
|
io.EOF, // we get EOF when we simulate a DERPMap error
|
|
yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session
|
|
),
|
|
}).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
clientID := uuid.UUID{0x66}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger.Named("svc"),
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialer := &pipeDialer{
|
|
ctx: testCtx,
|
|
logger: logger,
|
|
t: t,
|
|
svc: svc,
|
|
streamID: tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
},
|
|
}
|
|
|
|
peersLost := make(chan struct{})
|
|
fConn := &fakeTailnetConn{peersLostCh: peersLost}
|
|
|
|
uut := tailnet.NewController(logger.Named("ctrl"), dialer,
|
|
// darwin can be slow sometimes.
|
|
tailnet.WithGracefulTimeout(5*time.Second))
|
|
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger.Named("coord_ctrl"), fConn)
|
|
uut.DERPCtrl = tailnet.NewBasicDERPController(logger.Named("derp_ctrl"), fConn)
|
|
uut.Run(ctx)
|
|
|
|
call := testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
|
|
|
// simulate a problem with DERPMaps by sending nil
|
|
testutil.RequireSendCtx(testCtx, t, derpMapCh, nil)
|
|
|
|
// this should cause the coordinate call to hang up WITHOUT disconnecting
|
|
reqNil := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
|
require.Nil(t, reqNil)
|
|
|
|
// and mark all peers lost
|
|
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
|
|
|
// ...and then reconnect
|
|
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
|
|
|
// close the coordination call, which should cause a 2nd reconnection
|
|
close(call.Resps)
|
|
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
|
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
|
|
|
// canceling the context should trigger the disconnect message
|
|
cancel()
|
|
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
|
require.NotNil(t, reqDisc)
|
|
require.NotNil(t, reqDisc.Disconnect)
|
|
close(call.Resps)
|
|
|
|
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
|
_ = testutil.RequireRecvCtx(testCtx, t, uut.Closed())
|
|
}
|
|
|
|
func TestController_TelemetrySuccess(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
agentID := uuid.UUID{0x55}
|
|
clientID := uuid.UUID{0x66}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
eventCh := make(chan []*proto.TelemetryEvent, 1)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
|
select {
|
|
case <-ctx.Done():
|
|
t.Error("timeout sending telemetry event")
|
|
case eventCh <- batch:
|
|
t.Log("sent telemetry batch")
|
|
}
|
|
},
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
dialer := &pipeDialer{
|
|
ctx: ctx,
|
|
logger: logger,
|
|
t: t,
|
|
svc: svc,
|
|
streamID: tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
},
|
|
}
|
|
|
|
uut := tailnet.NewController(logger, dialer)
|
|
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger, &fakeTailnetConn{})
|
|
tel := tailnet.NewBasicTelemetryController(logger)
|
|
uut.TelemetryCtrl = tel
|
|
uut.Run(ctx)
|
|
// Coordinate calls happen _after_ telemetry is connected up, so we use this
|
|
// to ensure telemetry is connected before sending our event
|
|
cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
|
defer close(cc.Resps)
|
|
|
|
tel.SendTelemetryEvent(&proto.TelemetryEvent{
|
|
Id: []byte("test event"),
|
|
})
|
|
|
|
testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)
|
|
|
|
require.Len(t, testEvents, 1)
|
|
require.Equal(t, []byte("test event"), testEvents[0].Id)
|
|
}
|
|
|
|
func TestController_WorkspaceUpdates(t *testing.T) {
|
|
t.Parallel()
|
|
theError := xerrors.New("a bad thing happened")
|
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
|
ctx, cancel := context.WithCancel(testCtx)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, theError),
|
|
}).Leveled(slog.LevelDebug)
|
|
|
|
fClient := newFakeWorkspaceUpdateClient(testCtx, t)
|
|
dialer := &fakeWorkspaceUpdatesDialer{
|
|
client: fClient,
|
|
}
|
|
|
|
uut := tailnet.NewController(logger.Named("ctrl"), dialer)
|
|
fCtrl := newFakeUpdatesController(ctx, t)
|
|
uut.WorkspaceUpdatesCtrl = fCtrl
|
|
uut.Run(ctx)
|
|
|
|
// it should dial and pass the client to the controller
|
|
call := testutil.RequireRecvCtx(testCtx, t, fCtrl.calls)
|
|
require.Equal(t, fClient, call.client)
|
|
fCW := newFakeCloserWaiter()
|
|
testutil.RequireSendCtx[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
|
|
|
|
// if the CloserWaiter exits...
|
|
testutil.RequireSendCtx(testCtx, t, fCW.errCh, theError)
|
|
|
|
// it should close, redial and reconnect
|
|
cCall := testutil.RequireRecvCtx(testCtx, t, fClient.close)
|
|
testutil.RequireSendCtx(testCtx, t, cCall, nil)
|
|
|
|
call = testutil.RequireRecvCtx(testCtx, t, fCtrl.calls)
|
|
require.Equal(t, fClient, call.client)
|
|
fCW = newFakeCloserWaiter()
|
|
testutil.RequireSendCtx[tailnet.CloserWaiter](testCtx, t, call.resp, fCW)
|
|
|
|
// canceling the context should close the client
|
|
cancel()
|
|
cCall = testutil.RequireRecvCtx(testCtx, t, fClient.close)
|
|
testutil.RequireSendCtx(testCtx, t, cCall, nil)
|
|
}
|
|
|
|
type fakeTailnetConn struct {
|
|
peersLostCh chan struct{}
|
|
}
|
|
|
|
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
|
// TODO implement me
|
|
panic("implement me")
|
|
}
|
|
|
|
func (f *fakeTailnetConn) SetAllPeersLost() {
|
|
if f.peersLostCh == nil {
|
|
return
|
|
}
|
|
f.peersLostCh <- struct{}{}
|
|
}
|
|
|
|
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
|
|
|
|
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
|
|
|
|
func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}
|
|
|
|
type pipeDialer struct {
|
|
ctx context.Context
|
|
logger slog.Logger
|
|
t testing.TB
|
|
svc *tailnet.ClientService
|
|
streamID tailnet.StreamID
|
|
}
|
|
|
|
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
|
s, c := net.Pipe()
|
|
p.t.Cleanup(func() {
|
|
_ = s.Close()
|
|
_ = c.Close()
|
|
})
|
|
go func() {
|
|
err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
|
|
p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
|
|
}()
|
|
client, err := tailnet.NewDRPCClient(c, p.logger)
|
|
if err != nil {
|
|
_ = c.Close()
|
|
return tailnet.ControlProtocolClients{}, err
|
|
}
|
|
coord, err := client.Coordinate(context.Background())
|
|
if err != nil {
|
|
_ = c.Close()
|
|
return tailnet.ControlProtocolClients{}, err
|
|
}
|
|
|
|
derps := &tailnet.DERPFromDRPCWrapper{}
|
|
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
|
if err != nil {
|
|
_ = c.Close()
|
|
return tailnet.ControlProtocolClients{}, err
|
|
}
|
|
return tailnet.ControlProtocolClients{
|
|
Closer: client.DRPCConn(),
|
|
Coordinator: coord,
|
|
DERP: derps,
|
|
ResumeToken: client,
|
|
Telemetry: client,
|
|
}, nil
|
|
}
|
|
|
|
// timeoutOnFakeErr is the error we send when fakes fail to send calls or receive responses before
|
|
// their context times out. We don't want to send the context error since that often doesn't trigger
|
|
// test failures or logging.
|
|
var timeoutOnFakeErr = xerrors.New("test timeout")
|
|
|
|
type fakeCoordinatorClient struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
reqs chan *coordReqCall
|
|
resps chan *coordRespCall
|
|
close chan chan<- error
|
|
}
|
|
|
|
func (f fakeCoordinatorClient) Close() error {
|
|
f.t.Helper()
|
|
errs := make(chan error)
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case f.close <- errs:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (f fakeCoordinatorClient) Send(request *proto.CoordinateRequest) error {
|
|
f.t.Helper()
|
|
errs := make(chan error)
|
|
call := &coordReqCall{
|
|
req: request,
|
|
err: errs,
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case f.reqs <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (f fakeCoordinatorClient) Recv() (*proto.CoordinateResponse, error) {
|
|
f.t.Helper()
|
|
resps := make(chan *proto.CoordinateResponse)
|
|
errs := make(chan error)
|
|
call := &coordRespCall{
|
|
resp: resps,
|
|
err: errs,
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case f.resps <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return nil, err
|
|
case resp := <-resps:
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
func newFakeCoordinatorClient(ctx context.Context, t testing.TB) *fakeCoordinatorClient {
|
|
return &fakeCoordinatorClient{
|
|
ctx: ctx,
|
|
t: t,
|
|
reqs: make(chan *coordReqCall),
|
|
resps: make(chan *coordRespCall),
|
|
close: make(chan chan<- error),
|
|
}
|
|
}
|
|
|
|
type coordReqCall struct {
|
|
req *proto.CoordinateRequest
|
|
err chan<- error
|
|
}
|
|
|
|
type coordRespCall struct {
|
|
resp chan<- *proto.CoordinateResponse
|
|
err chan<- error
|
|
}
|
|
|
|
type fakeWorkspaceUpdateClient struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
recv chan *updateRecvCall
|
|
close chan chan<- error
|
|
}
|
|
|
|
func (f *fakeWorkspaceUpdateClient) Close() error {
|
|
f.t.Helper()
|
|
errs := make(chan error)
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case f.close <- errs:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (f *fakeWorkspaceUpdateClient) Recv() (*proto.WorkspaceUpdate, error) {
|
|
f.t.Helper()
|
|
resps := make(chan *proto.WorkspaceUpdate)
|
|
errs := make(chan error)
|
|
call := &updateRecvCall{
|
|
resp: resps,
|
|
err: errs,
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case f.recv <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return nil, timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return nil, err
|
|
case resp := <-resps:
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
func newFakeWorkspaceUpdateClient(ctx context.Context, t testing.TB) *fakeWorkspaceUpdateClient {
|
|
return &fakeWorkspaceUpdateClient{
|
|
ctx: ctx,
|
|
t: t,
|
|
recv: make(chan *updateRecvCall),
|
|
close: make(chan chan<- error),
|
|
}
|
|
}
|
|
|
|
type updateRecvCall struct {
|
|
resp chan<- *proto.WorkspaceUpdate
|
|
err chan<- error
|
|
}
|
|
|
|
// testUUID returns a UUID with bytes set as b, but shifted 6 bytes so that service prefixes don't
|
|
// overwrite them.
|
|
func testUUID(b ...byte) uuid.UUID {
|
|
o := uuid.UUID{}
|
|
for i := range b {
|
|
o[i+6] = b[i]
|
|
}
|
|
return o
|
|
}
|
|
|
|
type fakeDNSSetter struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
calls chan *setDNSCall
|
|
}
|
|
|
|
type setDNSCall struct {
|
|
hosts map[dnsname.FQDN][]netip.Addr
|
|
err chan<- error
|
|
}
|
|
|
|
func newFakeDNSSetter(ctx context.Context, t testing.TB) *fakeDNSSetter {
|
|
return &fakeDNSSetter{
|
|
ctx: ctx,
|
|
t: t,
|
|
calls: make(chan *setDNSCall),
|
|
}
|
|
}
|
|
|
|
func (f *fakeDNSSetter) SetDNSHosts(hosts map[dnsname.FQDN][]netip.Addr) error {
|
|
f.t.Helper()
|
|
errs := make(chan error)
|
|
call := &setDNSCall{
|
|
hosts: hosts,
|
|
err: errs,
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case f.calls <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case err := <-errs:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func newFakeUpdateHandler(ctx context.Context, t testing.TB) *fakeUpdateHandler {
|
|
return &fakeUpdateHandler{
|
|
ctx: ctx,
|
|
t: t,
|
|
ch: make(chan tailnet.WorkspaceUpdate),
|
|
}
|
|
}
|
|
|
|
type fakeUpdateHandler struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
ch chan tailnet.WorkspaceUpdate
|
|
}
|
|
|
|
func (f *fakeUpdateHandler) Update(wu tailnet.WorkspaceUpdate) error {
|
|
f.t.Helper()
|
|
select {
|
|
case <-f.ctx.Done():
|
|
return timeoutOnFakeErr
|
|
case f.ch <- wu:
|
|
// OK
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func setupConnectedAllWorkspaceUpdatesController(
|
|
ctx context.Context, t testing.TB, logger slog.Logger, opts ...tailnet.TunnelAllOption,
|
|
) (
|
|
*fakeCoordinatorClient, *fakeWorkspaceUpdateClient, *tailnet.TunnelAllWorkspaceUpdatesController,
|
|
) {
|
|
fConn := &fakeCoordinatee{}
|
|
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, opts...)
|
|
|
|
// connect up a coordinator client, to track adding and removing tunnels
|
|
coordC := newFakeCoordinatorClient(ctx, t)
|
|
coordCW := tsc.New(coordC)
|
|
t.Cleanup(func() {
|
|
// hang up coord client
|
|
coordRecv := testutil.RequireRecvCtx(ctx, t, coordC.resps)
|
|
testutil.RequireSendCtx(ctx, t, coordRecv.err, io.EOF)
|
|
// sends close on client
|
|
cCall := testutil.RequireRecvCtx(ctx, t, coordC.close)
|
|
testutil.RequireSendCtx(ctx, t, cCall, nil)
|
|
err := testutil.RequireRecvCtx(ctx, t, coordCW.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
})
|
|
|
|
// connect up the updates client
|
|
updateC := newFakeWorkspaceUpdateClient(ctx, t)
|
|
updateCW := uut.New(updateC)
|
|
t.Cleanup(func() {
|
|
// hang up WorkspaceUpdates client
|
|
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, upRecvCall.err, io.EOF)
|
|
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
})
|
|
return coordC, updateC, uut
|
|
}
|
|
|
|
func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fUH := newFakeUpdateHandler(ctx, t)
|
|
fDNS := newFakeDNSSetter(ctx, t)
|
|
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
|
|
tailnet.WithDNS(fDNS, "testy"),
|
|
tailnet.WithHandler(fUH),
|
|
)
|
|
|
|
// Initial update contains 2 workspaces with 1 & 2 agents, respectively
|
|
w1ID := testUUID(1)
|
|
w2ID := testUUID(2)
|
|
w1a1ID := testUUID(1, 1)
|
|
w2a1ID := testUUID(2, 1)
|
|
w2a2ID := testUUID(2, 2)
|
|
initUp := &proto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*proto.Workspace{
|
|
{Id: w1ID[:], Name: "w1"},
|
|
{Id: w2ID[:], Name: "w2"},
|
|
},
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
|
|
{Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:]},
|
|
{Id: w2a2ID[:], Name: "w2a2", WorkspaceId: w2ID[:]},
|
|
},
|
|
}
|
|
|
|
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
|
|
|
|
// This should trigger AddTunnel for each agent
|
|
var adds []uuid.UUID
|
|
for range 3 {
|
|
coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
|
|
adds = append(adds, uuid.Must(uuid.FromBytes(coordCall.req.GetAddTunnel().GetId())))
|
|
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
|
|
}
|
|
require.Contains(t, adds, w1a1ID)
|
|
require.Contains(t, adds, w2a1ID)
|
|
require.Contains(t, adds, w2a2ID)
|
|
|
|
ws1a1IP := netip.MustParseAddr("fd60:627a:a42b:0101::")
|
|
w2a1IP := netip.MustParseAddr("fd60:627a:a42b:0201::")
|
|
w2a2IP := netip.MustParseAddr("fd60:627a:a42b:0202::")
|
|
|
|
// Also triggers setting DNS hosts
|
|
expectedDNS := map[dnsname.FQDN][]netip.Addr{
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w2a1.w2.me.coder.": {w2a1IP},
|
|
"w2a2.w2.me.coder.": {w2a2IP},
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
"w2a1.w2.testy.coder.": {w2a1IP},
|
|
"w2a2.w2.testy.coder.": {w2a2IP},
|
|
"w1.coder.": {ws1a1IP},
|
|
}
|
|
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
|
|
require.Equal(t, expectedDNS, dnsCall.hosts)
|
|
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
|
|
|
|
currentState := tailnet.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnet.Workspace{
|
|
{ID: w1ID, Name: "w1"},
|
|
{ID: w2ID, Name: "w2"},
|
|
},
|
|
UpsertedAgents: []*tailnet.Agent{
|
|
{
|
|
ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID,
|
|
Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w1.coder.": {ws1a1IP},
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
},
|
|
},
|
|
{
|
|
ID: w2a1ID, Name: "w2a1", WorkspaceID: w2ID,
|
|
Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w2a1.w2.me.coder.": {w2a1IP},
|
|
"w2a1.w2.testy.coder.": {w2a1IP},
|
|
},
|
|
},
|
|
{
|
|
ID: w2a2ID, Name: "w2a2", WorkspaceID: w2ID,
|
|
Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w2a2.w2.me.coder.": {w2a2IP},
|
|
"w2a2.w2.testy.coder.": {w2a2IP},
|
|
},
|
|
},
|
|
},
|
|
DeletedWorkspaces: []*tailnet.Workspace{},
|
|
DeletedAgents: []*tailnet.Agent{},
|
|
}
|
|
|
|
// And the callback
|
|
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
|
|
require.Equal(t, currentState, cbUpdate)
|
|
|
|
// Current recvState should match
|
|
recvState, err := updateCtrl.CurrentState()
|
|
require.NoError(t, err)
|
|
slices.SortFunc(recvState.UpsertedWorkspaces, func(a, b *tailnet.Workspace) int {
|
|
return strings.Compare(a.Name, b.Name)
|
|
})
|
|
slices.SortFunc(recvState.UpsertedAgents, func(a, b *tailnet.Agent) int {
|
|
return strings.Compare(a.Name, b.Name)
|
|
})
|
|
require.Equal(t, currentState, recvState)
|
|
}
|
|
|
|
func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := testutil.Logger(t)
|
|
|
|
fUH := newFakeUpdateHandler(ctx, t)
|
|
fDNS := newFakeDNSSetter(ctx, t)
|
|
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
|
|
tailnet.WithDNS(fDNS, "testy"),
|
|
tailnet.WithHandler(fUH),
|
|
)
|
|
|
|
w1ID := testUUID(1)
|
|
w1a1ID := testUUID(1, 1)
|
|
w1a2ID := testUUID(1, 2)
|
|
ws1a1IP := netip.MustParseAddr("fd60:627a:a42b:0101::")
|
|
ws1a2IP := netip.MustParseAddr("fd60:627a:a42b:0102::")
|
|
|
|
initUp := &proto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*proto.Workspace{
|
|
{Id: w1ID[:], Name: "w1"},
|
|
},
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
|
|
},
|
|
}
|
|
|
|
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
|
|
|
|
// Add for w1a1
|
|
coordCall := testutil.RequireRecvCtx(ctx, t, coordC.reqs)
|
|
require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
|
|
|
|
// DNS for w1a1
|
|
expectedDNS := map[dnsname.FQDN][]netip.Addr{
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w1.coder.": {ws1a1IP},
|
|
}
|
|
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
|
|
require.Equal(t, expectedDNS, dnsCall.hosts)
|
|
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
|
|
|
|
initRecvUp := tailnet.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnet.Workspace{
|
|
{ID: w1ID, Name: "w1"},
|
|
},
|
|
UpsertedAgents: []*tailnet.Agent{
|
|
{ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID, Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w1.coder.": {ws1a1IP},
|
|
}},
|
|
},
|
|
DeletedWorkspaces: []*tailnet.Workspace{},
|
|
DeletedAgents: []*tailnet.Agent{},
|
|
}
|
|
|
|
cbUpdate := testutil.RequireRecvCtx(ctx, t, fUH.ch)
|
|
require.Equal(t, initRecvUp, cbUpdate)
|
|
|
|
// Current state should match initial
|
|
state, err := updateCtrl.CurrentState()
|
|
require.NoError(t, err)
|
|
require.Equal(t, initRecvUp, state)
|
|
|
|
// Send update that removes w1a1 and adds w1a2
|
|
agentUpdate := &proto.WorkspaceUpdate{
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: w1a2ID[:], Name: "w1a2", WorkspaceId: w1ID[:]},
|
|
},
|
|
DeletedAgents: []*proto.Agent{
|
|
{Id: w1a1ID[:], WorkspaceId: w1ID[:]},
|
|
},
|
|
}
|
|
upRecvCall = testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, agentUpdate)
|
|
|
|
// Add for w1a2
|
|
coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
|
|
require.Equal(t, w1a2ID[:], coordCall.req.GetAddTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
|
|
|
|
// Remove for w1a1
|
|
coordCall = testutil.RequireRecvCtx(ctx, t, coordC.reqs)
|
|
require.Equal(t, w1a1ID[:], coordCall.req.GetRemoveTunnel().GetId())
|
|
testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
|
|
|
|
// DNS contains only w1a2
|
|
expectedDNS = map[dnsname.FQDN][]netip.Addr{
|
|
"w1a2.w1.testy.coder.": {ws1a2IP},
|
|
"w1a2.w1.me.coder.": {ws1a2IP},
|
|
"w1.coder.": {ws1a2IP},
|
|
}
|
|
dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls)
|
|
require.Equal(t, expectedDNS, dnsCall.hosts)
|
|
testutil.RequireSendCtx(ctx, t, dnsCall.err, nil)
|
|
|
|
cbUpdate = testutil.RequireRecvCtx(ctx, t, fUH.ch)
|
|
sndRecvUpdate := tailnet.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnet.Workspace{},
|
|
UpsertedAgents: []*tailnet.Agent{
|
|
{ID: w1a2ID, Name: "w1a2", WorkspaceID: w1ID, Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w1a2.w1.testy.coder.": {ws1a2IP},
|
|
"w1a2.w1.me.coder.": {ws1a2IP},
|
|
"w1.coder.": {ws1a2IP},
|
|
}},
|
|
},
|
|
DeletedWorkspaces: []*tailnet.Workspace{},
|
|
DeletedAgents: []*tailnet.Agent{
|
|
{ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID, Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w1.coder.": {ws1a1IP},
|
|
}},
|
|
},
|
|
}
|
|
require.Equal(t, sndRecvUpdate, cbUpdate)
|
|
|
|
state, err = updateCtrl.CurrentState()
|
|
require.NoError(t, err)
|
|
require.Equal(t, tailnet.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*tailnet.Workspace{
|
|
{ID: w1ID, Name: "w1"},
|
|
},
|
|
UpsertedAgents: []*tailnet.Agent{
|
|
{ID: w1a2ID, Name: "w1a2", WorkspaceID: w1ID, Hosts: map[dnsname.FQDN][]netip.Addr{
|
|
"w1a2.w1.testy.coder.": {ws1a2IP},
|
|
"w1a2.w1.me.coder.": {ws1a2IP},
|
|
"w1.coder.": {ws1a2IP},
|
|
}},
|
|
},
|
|
DeletedWorkspaces: []*tailnet.Workspace{},
|
|
DeletedAgents: []*tailnet.Agent{},
|
|
}, state)
|
|
}
|
|
|
|
func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
dnsError := xerrors.New("a bad thing happened")
|
|
logger := slogtest.Make(t,
|
|
&slogtest.Options{IgnoredErrorIs: []error{dnsError}}).
|
|
Leveled(slog.LevelDebug)
|
|
|
|
fDNS := newFakeDNSSetter(ctx, t)
|
|
fConn := &fakeCoordinatee{}
|
|
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc,
|
|
tailnet.WithDNS(fDNS, "testy"),
|
|
)
|
|
|
|
updateC := newFakeWorkspaceUpdateClient(ctx, t)
|
|
updateCW := uut.New(updateC)
|
|
|
|
w1ID := testUUID(1)
|
|
w1a1ID := testUUID(1, 1)
|
|
ws1a1IP := netip.MustParseAddr("fd60:627a:a42b:0101::")
|
|
|
|
initUp := &proto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*proto.Workspace{
|
|
{Id: w1ID[:], Name: "w1"},
|
|
},
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:]},
|
|
},
|
|
}
|
|
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
|
|
|
|
// DNS for w1a1
|
|
expectedDNS := map[dnsname.FQDN][]netip.Addr{
|
|
"w1a1.w1.me.coder.": {ws1a1IP},
|
|
"w1a1.w1.testy.coder.": {ws1a1IP},
|
|
"w1.coder.": {ws1a1IP},
|
|
}
|
|
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
|
|
require.Equal(t, expectedDNS, dnsCall.hosts)
|
|
testutil.RequireSendCtx(ctx, t, dnsCall.err, dnsError)
|
|
|
|
// should trigger a close on the client
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, io.EOF)
|
|
|
|
// error should be our initial DNS error
|
|
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
|
|
require.ErrorIs(t, err, dnsError)
|
|
}
|
|
|
|
func TestTunnelAllWorkspaceUpdatesController_HandleErrors(t *testing.T) {
|
|
t.Parallel()
|
|
validWorkspaceID := testUUID(1)
|
|
validAgentID := testUUID(1, 1)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
update *proto.WorkspaceUpdate
|
|
errorContains string
|
|
}{
|
|
{
|
|
name: "unparsableUpsertWorkspaceID",
|
|
update: &proto.WorkspaceUpdate{
|
|
UpsertedWorkspaces: []*proto.Workspace{
|
|
{Id: []byte{2, 2}, Name: "bander"},
|
|
},
|
|
},
|
|
errorContains: "failed to parse workspace ID",
|
|
},
|
|
{
|
|
name: "unparsableDeleteWorkspaceID",
|
|
update: &proto.WorkspaceUpdate{
|
|
DeletedWorkspaces: []*proto.Workspace{
|
|
{Id: []byte{2, 2}, Name: "bander"},
|
|
},
|
|
},
|
|
errorContains: "failed to parse workspace ID",
|
|
},
|
|
{
|
|
name: "unparsableDeleteAgentWorkspaceID",
|
|
update: &proto.WorkspaceUpdate{
|
|
DeletedAgents: []*proto.Agent{
|
|
{Id: validAgentID[:], Name: "devo", WorkspaceId: []byte{2, 2}},
|
|
},
|
|
},
|
|
errorContains: "failed to parse workspace ID",
|
|
},
|
|
{
|
|
name: "unparsableUpsertAgentWorkspaceID",
|
|
update: &proto.WorkspaceUpdate{
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: validAgentID[:], Name: "devo", WorkspaceId: []byte{2, 2}},
|
|
},
|
|
},
|
|
errorContains: "failed to parse workspace ID",
|
|
},
|
|
{
|
|
name: "unparsableDeleteAgentID",
|
|
update: &proto.WorkspaceUpdate{
|
|
DeletedAgents: []*proto.Agent{
|
|
{Id: []byte{2, 2}, Name: "devo", WorkspaceId: validWorkspaceID[:]},
|
|
},
|
|
},
|
|
errorContains: "failed to parse agent ID",
|
|
},
|
|
{
|
|
name: "unparsableUpsertAgentID",
|
|
update: &proto.WorkspaceUpdate{
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: []byte{2, 2}, Name: "devo", WorkspaceId: validWorkspaceID[:]},
|
|
},
|
|
},
|
|
errorContains: "failed to parse agent ID",
|
|
},
|
|
{
|
|
name: "upsertAgentMissingWorkspace",
|
|
update: &proto.WorkspaceUpdate{
|
|
UpsertedAgents: []*proto.Agent{
|
|
{Id: validAgentID[:], Name: "devo", WorkspaceId: validWorkspaceID[:]},
|
|
},
|
|
},
|
|
errorContains: fmt.Sprintf("workspace %s not found", validWorkspaceID.String()),
|
|
},
|
|
{
|
|
name: "deleteAgentMissingWorkspace",
|
|
update: &proto.WorkspaceUpdate{
|
|
DeletedAgents: []*proto.Agent{
|
|
{Id: validAgentID[:], Name: "devo", WorkspaceId: validWorkspaceID[:]},
|
|
},
|
|
},
|
|
errorContains: fmt.Sprintf("workspace %s not found", validWorkspaceID.String()),
|
|
},
|
|
}
|
|
// nolint: paralleltest // no longer need to reinitialize loop vars in go 1.22
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
|
|
|
fConn := &fakeCoordinatee{}
|
|
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
|
|
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc)
|
|
updateC := newFakeWorkspaceUpdateClient(ctx, t)
|
|
updateCW := uut.New(updateC)
|
|
|
|
recvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
|
|
testutil.RequireSendCtx(ctx, t, recvCall.resp, tc.update)
|
|
closeCall := testutil.RequireRecvCtx(ctx, t, updateC.close)
|
|
testutil.RequireSendCtx(ctx, t, closeCall, nil)
|
|
|
|
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
|
|
require.ErrorContains(t, err, tc.errorContains)
|
|
})
|
|
}
|
|
}
|
|
|
|
type fakeWorkspaceUpdatesController struct {
|
|
ctx context.Context
|
|
t testing.TB
|
|
calls chan *newWorkspaceUpdatesCall
|
|
}
|
|
|
|
func (*fakeWorkspaceUpdatesController) CurrentState() *proto.WorkspaceUpdate {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
type newWorkspaceUpdatesCall struct {
|
|
client tailnet.WorkspaceUpdatesClient
|
|
resp chan<- tailnet.CloserWaiter
|
|
}
|
|
|
|
func (f fakeWorkspaceUpdatesController) New(client tailnet.WorkspaceUpdatesClient) tailnet.CloserWaiter {
|
|
resps := make(chan tailnet.CloserWaiter)
|
|
call := &newWorkspaceUpdatesCall{
|
|
client: client,
|
|
resp: resps,
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
cw := newFakeCloserWaiter()
|
|
cw.errCh <- timeoutOnFakeErr
|
|
return cw
|
|
case f.calls <- call:
|
|
// OK
|
|
}
|
|
select {
|
|
case <-f.ctx.Done():
|
|
cw := newFakeCloserWaiter()
|
|
cw.errCh <- timeoutOnFakeErr
|
|
return cw
|
|
case resp := <-resps:
|
|
return resp
|
|
}
|
|
}
|
|
|
|
func newFakeUpdatesController(ctx context.Context, t *testing.T) *fakeWorkspaceUpdatesController {
|
|
return &fakeWorkspaceUpdatesController{
|
|
ctx: ctx,
|
|
t: t,
|
|
calls: make(chan *newWorkspaceUpdatesCall),
|
|
}
|
|
}
|
|
|
|
// fakeCloserWaiter is a scripted CloserWaiter.
//
// Close publishes a reply channel on closeCalls and returns whatever the test
// sends back. Wait exposes errCh, which is filled to simulate the waiter
// exiting.
type fakeCloserWaiter struct {
	closeCalls chan chan error
	errCh      chan error
}

// Close hands a reply channel to the test via f.closeCalls and returns the
// error sent back on it, or ctx.Err() if ctx is done first.
func (f *fakeCloserWaiter) Close(ctx context.Context) error {
	reply := make(chan error)
	select {
	case f.closeCalls <- reply:
	case <-ctx.Done():
		return ctx.Err()
	}
	select {
	case err := <-reply:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Wait returns the channel on which the waiter's exit error is delivered.
func (f *fakeCloserWaiter) Wait() <-chan error { return f.errCh }

// newFakeCloserWaiter returns a fakeCloserWaiter with an unbuffered closeCalls
// channel and a one-slot errCh.
func newFakeCloserWaiter() *fakeCloserWaiter {
	cw := &fakeCloserWaiter{closeCalls: make(chan chan error)}
	// One slot lets a result be recorded before anyone calls Wait.
	cw.errCh = make(chan error, 1)
	return cw
}
|
|
|
|
type fakeWorkspaceUpdatesDialer struct {
|
|
client tailnet.WorkspaceUpdatesClient
|
|
}
|
|
|
|
func (f *fakeWorkspaceUpdatesDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
|
return tailnet.ControlProtocolClients{
|
|
WorkspaceUpdates: f.client,
|
|
Closer: fakeCloser{},
|
|
}, nil
|
|
}
|
|
|
|
// fakeCloser satisfies io.Closer with a Close that does nothing.
type fakeCloser struct{}

// Close always succeeds.
func (fakeCloser) Close() error { return nil }
|