chore: remove tailnet v1 API support (#14641)

Drops support for v1 of the tailnet API, which was the original coordination protocol in which we only sent node updates and never marked peers as lost or disconnected.

v2 of the tailnet API went GA for CLI clients in Coder 2.8.0, so CLI clients older than 2.8.0 will stop working after this change.
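
For context, the exchange that remains is the v2 Coordinate API. The sketch below is illustrative only, assembled from the CoordinatorV2 interface and proto messages visible in the diff below; the wrapper function, package name, and DERP value are assumptions, not code from this commit.

package example

import (
	"context"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// coordinateWithAgent is an illustrative sketch, not part of this commit.
func coordinateWithAgent(ctx context.Context, coord tailnet.CoordinatorV2, clientID, agentID uuid.UUID) {
	// Coordinate returns a channel for our requests and a channel of peer
	// updates which, unlike v1, can mark peers LOST or DISCONNECTED instead
	// of only ever streaming node updates.
	reqs, resps := coord.Coordinate(ctx, clientID, clientID.String(),
		tailnet.ClientCoordinateeAuth{AgentID: agentID})

	// Ask to be tunneled to the agent, then advertise our own node.
	reqs <- &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(agentID)},
	}
	reqs <- &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 10}},
	}

	for resp := range resps {
		for _, update := range resp.GetPeerUpdates() {
			switch update.GetKind() {
			case proto.CoordinateResponse_PeerUpdate_NODE:
				// New or updated node information for a peer.
			case proto.CoordinateResponse_PeerUpdate_LOST,
				proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
				// Signals the v1 protocol could never express.
			}
		}
	}
}

Everything above already exists on the v2 path; this commit removes only the JSON-over-net.Conn v1 adapters (ServeClient, ServeAgent, ServeCoordinator, and the wsproxy node forwarding) that predate it.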
Authored by Spike Curtis on 2024-09-12 07:56:31 +04:00; committed by GitHub
parent fb3523b37f
commit d6154c4310
14 changed files with 504 additions and 1412 deletions

View File

@ -10,8 +10,8 @@ import (
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet" "github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/tailnet/tailnettest"
agpltest "github.com/coder/coder/v2/tailnet/test"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -35,27 +35,27 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord1.Close() defer coord1.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1) ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close() defer ma1.Close()
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) { func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
@ -102,28 +102,28 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord1.Close() defer coord1.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1) ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close() defer ma1.Close()
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.RequireUnsubscribeAgent(agent1.id) ma1.RequireUnsubscribeAgent(agent1.ID)
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a // TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
@ -147,43 +147,43 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord1.Close() defer coord1.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord1) ma1 := tailnettest.NewTestMultiAgent(t, coord1)
defer ma1.Close() defer ma1.Close()
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.RequireUnsubscribeAgent(agent1.id) ma1.RequireUnsubscribeAgent(agent1.ID)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
func() { func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel() defer cancel()
ma1.SendNodeWithDERP(9) ma1.SendNodeWithDERP(9)
assertNeverHasDERPs(ctx, t, agent1, 9) agent1.AssertNeverHasDERPs(ctx, ma1.ID, 9)
}() }()
func() { func() {
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
defer cancel() defer cancel()
agent1.sendNode(&agpl.Node{PreferredDERP: 8}) agent1.UpdateDERP(8)
ma1.RequireNeverHasDERPs(ctx, 8) ma1.RequireNeverHasDERPs(ctx, 8)
}() }()
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a // TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
@ -212,27 +212,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord2.Close() defer coord2.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord2) ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close() defer ma1.Close()
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two // TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
@ -262,27 +262,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
require.NoError(t, err) require.NoError(t, err)
defer coord2.Close() defer coord2.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
ma1 := tailnettest.NewTestMultiAgent(t, coord2) ma1 := tailnettest.NewTestMultiAgent(t, coord2)
defer ma1.Close() defer ma1.Close()
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a // TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
@ -317,37 +317,37 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord3.Close() defer coord3.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent1.sendNode(&agpl.Node{PreferredDERP: 5}) agent1.UpdateDERP(5)
agent2 := newTestAgent(t, coord2, "agent2") agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
defer agent1.close() defer agent2.Close(ctx)
agent2.sendNode(&agpl.Node{PreferredDERP: 6}) agent2.UpdateDERP(6)
ma1 := tailnettest.NewTestMultiAgent(t, coord3) ma1 := tailnettest.NewTestMultiAgent(t, coord3)
defer ma1.Close() defer ma1.Close()
ma1.RequireSubscribeAgent(agent1.id) ma1.RequireSubscribeAgent(agent1.ID)
ma1.RequireEventuallyHasDERPs(ctx, 5) ma1.RequireEventuallyHasDERPs(ctx, 5)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.RequireEventuallyHasDERPs(ctx, 1)
ma1.RequireSubscribeAgent(agent2.id) ma1.RequireSubscribeAgent(agent2.ID)
ma1.RequireEventuallyHasDERPs(ctx, 6) ma1.RequireEventuallyHasDERPs(ctx, 6)
agent2.sendNode(&agpl.Node{PreferredDERP: 2}) agent2.UpdateDERP(2)
ma1.RequireEventuallyHasDERPs(ctx, 2) ma1.RequireEventuallyHasDERPs(ctx, 2)
ma1.SendNodeWithDERP(3) ma1.SendNodeWithDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(ma1.ID, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3) agent2.AssertEventuallyHasDERP(ma1.ID, 3)
ma1.Close() ma1.Close()
require.NoError(t, agent1.close()) agent1.UngracefulDisconnect(ctx)
require.NoError(t, agent2.close()) agent2.UngracefulDisconnect(ctx)
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
assertEventuallyLost(ctx, t, store, agent1.id) assertEventuallyLost(ctx, t, store, agent1.ID)
} }
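
Throughout the coordinator tests, the package-local testConn helpers (newTestAgent, newTestClient, sendNode, assertEventuallyHasDERPs) are replaced by the shared agpltest peer helpers, as in the hunks above. A condensed, hypothetical example of the resulting test shape follows; the test name, timeout, and setup details are assumptions rather than code from this commit.

package tailnet_test

import (
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/coderd/database/dbtestutil"
	"github.com/coder/coder/v2/enterprise/tailnet"
	agpltest "github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

// TestPGCoordinator_Sketch is a hypothetical example, not part of this commit.
func TestPGCoordinator_Sketch(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitSuperLong)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)

	coord, err := tailnet.NewPGCoord(ctx, logger, ps, store)
	require.NoError(t, err)
	defer coord.Close()

	// Peers are driven through the shared agpltest helpers rather than
	// hand-rolled net.Pipe connections and ServeAgent/ServeClient.
	agent := agpltest.NewAgent(ctx, t, coord, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(5)

	client := agpltest.NewClient(ctx, t, coord, "client", agent.ID)
	defer client.Close(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 5)

	client.UpdateDERP(6)
	agent.AssertEventuallyHasDERP(client.ID, 6)

	// UngracefulDisconnect drops the connection without a graceful
	// Disconnect message, so the coordinator should mark the peer lost.
	client.UngracefulDisconnect(ctx)
	agent.UngracefulDisconnect(ctx)
}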

View File

@ -3,7 +3,6 @@ package tailnet
import ( import (
"context" "context"
"database/sql" "database/sql"
"net"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -213,14 +212,6 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
return node return node
} }
func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent)
}
func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name)
}
func (c *pgCoord) Close() error { func (c *pgCoord) Close() error {
c.logger.Info(c.ctx, "closing coordinator") c.logger.Info(c.ctx, "closing coordinator")
c.cancel() c.cancel()

View File

@ -3,8 +3,6 @@ package tailnet_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"io"
"net"
"net/netip" "net/netip"
"sync" "sync"
"testing" "testing"
@ -15,7 +13,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"golang.org/x/exp/slices"
"golang.org/x/xerrors" "golang.org/x/xerrors"
gProto "google.golang.org/protobuf/proto" gProto "google.golang.org/protobuf/proto"
@ -51,9 +48,9 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
defer coordinator.Close() defer coordinator.Close()
agentID := uuid.New() agentID := uuid.New()
client := newTestClient(t, coordinator, agentID) client := agpltest.NewClient(ctx, t, coordinator, "client", agentID)
defer client.close() defer client.Close(ctx)
client.sendNode(&agpl.Node{PreferredDERP: 10}) client.UpdateDERP(10)
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID) clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) { if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
@ -68,12 +65,8 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
assert.EqualValues(t, 10, node.PreferredDerp) assert.EqualValues(t, 10, node.PreferredDerp)
return true return true
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
client.UngracefulDisconnect(ctx)
err = client.close() assertEventuallyLost(ctx, t, store, client.ID)
require.NoError(t, err)
<-client.errChan
<-client.closeChan
assertEventuallyLost(ctx, t, store, client.id)
} }
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
@ -89,11 +82,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent") agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{PreferredDERP: 10}) agent.UpdateDERP(10)
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id) agents, err := store.GetTailnetPeers(ctx, agent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) { if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err) t.Fatalf("database error: %v", err)
} }
@ -106,11 +99,8 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
assert.EqualValues(t, 10, node.PreferredDerp) assert.EqualValues(t, 10, node.PreferredDerp)
return true return true
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
err = agent.close() agent.UngracefulDisconnect(ctx)
require.NoError(t, err) assertEventuallyLost(ctx, t, store, agent.ID)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
} }
func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) { func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
@ -126,18 +116,18 @@ func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent") agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{ agent.UpdateNode(&proto.Node{
Addresses: []netip.Prefix{ Addresses: []string{
netip.PrefixFrom(agpl.IP(), 128), netip.PrefixFrom(agpl.IP(), 128).String(),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
// The agent connection should be closed immediately after sending an invalid addr // The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan) agent.AssertEventuallyResponsesClosed()
assertEventuallyLost(ctx, t, store, agent.id) assertEventuallyLost(ctx, t, store, agent.ID)
} }
func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) { func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
@ -153,18 +143,18 @@ func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent") agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{ agent.UpdateNode(&proto.Node{
Addresses: []netip.Prefix{ Addresses: []string{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64), netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 64).String(),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
// The agent connection should be closed immediately after sending an invalid addr // The agent connection should be closed immediately after sending an invalid addr
testutil.RequireRecvCtx(ctx, t, agent.closeChan) agent.AssertEventuallyResponsesClosed()
assertEventuallyLost(ctx, t, store, agent.id) assertEventuallyLost(ctx, t, store, agent.ID)
} }
func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) { func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
@ -180,16 +170,16 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent") agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{ agent.UpdateNode(&proto.Node{
Addresses: []netip.Prefix{ Addresses: []string{
netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128), netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 128).String(),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
agents, err := store.GetTailnetPeers(ctx, agent.id) agents, err := store.GetTailnetPeers(ctx, agent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) { if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
t.Fatalf("database error: %v", err) t.Fatalf("database error: %v", err)
} }
@ -202,11 +192,8 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
assert.EqualValues(t, 10, node.PreferredDerp) assert.EqualValues(t, 10, node.PreferredDerp)
return true return true
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
err = agent.close() agent.UngracefulDisconnect(ctx)
require.NoError(t, err) assertEventuallyLost(ctx, t, store, agent.ID)
<-agent.errChan
<-agent.closeChan
assertEventuallyLost(ctx, t, store, agent.id)
} }
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
@ -222,68 +209,40 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "original") agent := agpltest.NewAgent(ctx, t, coordinator, "original")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{PreferredDERP: 10}) agent.UpdateDERP(10)
client := newTestClient(t, coordinator, agent.id) client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
defer client.close() defer client.Close(ctx)
agentNodes := client.recvNodes(ctx, t) client.AssertEventuallyHasDERP(agent.ID, 10)
require.Len(t, agentNodes, 1) client.UpdateDERP(11)
assert.Equal(t, 10, agentNodes[0].PreferredDERP) agent.AssertEventuallyHasDERP(client.ID, 11)
client.sendNode(&agpl.Node{PreferredDERP: 11})
clientNodes := agent.recvNodes(ctx, t)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)
// Ensure an update to the agent node reaches the connIO! // Ensure an update to the agent node reaches the connIO!
agent.sendNode(&agpl.Node{PreferredDERP: 12}) agent.UpdateDERP(12)
agentNodes = client.recvNodes(ctx, t) client.AssertEventuallyHasDERP(agent.ID, 12)
require.Len(t, agentNodes, 1)
assert.Equal(t, 12, agentNodes[0].PreferredDERP)
// Close the agent WebSocket so a new one can connect. // Close the agent channel so a new one can connect.
err = agent.close() agent.Close(ctx)
require.NoError(t, err)
_ = agent.recvErr(ctx, t)
agent.waitForClose(ctx, t)
// Create a new agent connection. This is to simulate a reconnect! // Create a new agent connection. This is to simulate a reconnect!
agent = newTestAgent(t, coordinator, "reconnection", agent.id) agent = agpltest.NewPeer(ctx, t, coordinator, "reconnection", agpltest.WithID(agent.ID))
// Ensure the existing listening connIO sends its node immediately! // Ensure the coordinator sends its client node immediately!
clientNodes = agent.recvNodes(ctx, t) agent.AssertEventuallyHasDERP(client.ID, 11)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)
// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the // Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
// coordinator accidentally reordering things. // coordinator accidentally reordering things.
for d := 13; d < 36; d++ { for d := int32(13); d < 36; d++ {
agent.sendNode(&agpl.Node{PreferredDERP: d}) agent.UpdateDERP(d)
}
for {
nodes := client.recvNodes(ctx, t)
if !assert.Len(t, nodes, 1) {
break
}
if nodes[0].PreferredDERP == 35 {
// got latest!
break
}
} }
client.AssertEventuallyHasDERP(agent.ID, 35)
err = agent.close() agent.UngracefulDisconnect(ctx)
require.NoError(t, err) client.UngracefulDisconnect(ctx)
_ = agent.recvErr(ctx, t) assertEventuallyLost(ctx, t, store, agent.ID)
agent.waitForClose(ctx, t) assertEventuallyLost(ctx, t, store, client.ID)
err = client.close()
require.NoError(t, err)
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, agent.id)
assertEventuallyLost(ctx, t, store, client.id)
} }
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
@ -305,16 +264,16 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "agent") agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{PreferredDERP: 10}) agent.UpdateDERP(10)
client := newTestClient(t, coordinator, agent.id) client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
defer client.close() defer client.Close(ctx)
assertEventuallyHasDERPs(ctx, t, client, 10) client.AssertEventuallyHasDERP(agent.ID, 10)
client.sendNode(&agpl.Node{PreferredDERP: 11}) client.UpdateDERP(11)
assertEventuallyHasDERPs(ctx, t, agent, 11) agent.AssertEventuallyHasDERP(client.ID, 11)
// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a // simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
// real coordinator // real coordinator
@ -328,8 +287,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
fCoord2.heartbeat() fCoord2.heartbeat()
afTrap.MustWait(ctx).Release() // heartbeat timeout started afTrap.MustWait(ctx).Release() // heartbeat timeout started
fCoord2.agentNode(agent.id, &agpl.Node{PreferredDERP: 12}) fCoord2.agentNode(agent.ID, &agpl.Node{PreferredDERP: 12})
assertEventuallyHasDERPs(ctx, t, client, 12) client.AssertEventuallyHasDERP(agent.ID, 12)
fCoord3 := &fakeCoordinator{ fCoord3 := &fakeCoordinator{
ctx: ctx, ctx: ctx,
@ -339,8 +298,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
} }
fCoord3.heartbeat() fCoord3.heartbeat()
rstTrap.MustWait(ctx).Release() // timeout gets reset rstTrap.MustWait(ctx).Release() // timeout gets reset
fCoord3.agentNode(agent.id, &agpl.Node{PreferredDERP: 13}) fCoord3.agentNode(agent.ID, &agpl.Node{PreferredDERP: 13})
assertEventuallyHasDERPs(ctx, t, client, 13) client.AssertEventuallyHasDERP(agent.ID, 13)
// fCoord2 sends in a second heartbeat, one period later (on time) // fCoord2 sends in a second heartbeat, one period later (on time)
mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
@ -353,30 +312,22 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
w := mClock.Advance(tailnet.HeartbeatPeriod) w := mClock.Advance(tailnet.HeartbeatPeriod)
rstTrap.MustWait(ctx).Release() rstTrap.MustWait(ctx).Release()
w.MustWait(ctx) w.MustWait(ctx)
assertEventuallyHasDERPs(ctx, t, client, 12) client.AssertEventuallyHasDERP(agent.ID, 12)
// one more heartbeat period will result in fCoord2 being expired, which should cause us to // one more heartbeat period will result in fCoord2 being expired, which should cause us to
// revert to the original agent mapping // revert to the original agent mapping
mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
// note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired // note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired
assertEventuallyHasDERPs(ctx, t, client, 10) client.AssertEventuallyHasDERP(agent.ID, 10)
// send fCoord3 heartbeat, which should trigger us to consider that mapping valid again. // send fCoord3 heartbeat, which should trigger us to consider that mapping valid again.
fCoord3.heartbeat() fCoord3.heartbeat()
rstTrap.MustWait(ctx).Release() // timeout gets reset rstTrap.MustWait(ctx).Release() // timeout gets reset
assertEventuallyHasDERPs(ctx, t, client, 13) client.AssertEventuallyHasDERP(agent.ID, 13)
err = agent.close() agent.UngracefulDisconnect(ctx)
require.NoError(t, err) client.UngracefulDisconnect(ctx)
_ = agent.recvErr(ctx, t) assertEventuallyLost(ctx, t, store, client.ID)
agent.waitForClose(ctx, t)
err = client.close()
require.NoError(t, err)
_ = client.recvErr(ctx, t)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, client.id)
} }
func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) { func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
@ -420,7 +371,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
// disconnect. // disconnect.
client.AssertEventuallyLost(agentID) client.AssertEventuallyLost(agentID)
client.Close(ctx) client.UngracefulDisconnect(ctx)
assertEventuallyLost(ctx, t, store, client.ID) assertEventuallyLost(ctx, t, store, client.ID)
} }
@ -491,104 +442,73 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord2.Close() defer coord2.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
t.Logf("agent1=%s", agent1.id) t.Logf("agent1=%s", agent1.ID)
agent2 := newTestAgent(t, coord2, "agent2") agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
defer agent2.close() defer agent2.Close(ctx)
t.Logf("agent2=%s", agent2.id) t.Logf("agent2=%s", agent2.ID)
client11 := newTestClient(t, coord1, agent1.id) client11 := agpltest.NewClient(ctx, t, coord1, "client11", agent1.ID)
defer client11.close() defer client11.Close(ctx)
t.Logf("client11=%s", client11.id) t.Logf("client11=%s", client11.ID)
client12 := newTestClient(t, coord1, agent2.id) client12 := agpltest.NewClient(ctx, t, coord1, "client12", agent2.ID)
defer client12.close() defer client12.Close(ctx)
t.Logf("client12=%s", client12.id) t.Logf("client12=%s", client12.ID)
client21 := newTestClient(t, coord2, agent1.id) client21 := agpltest.NewClient(ctx, t, coord2, "client21", agent1.ID)
defer client21.close() defer client21.Close(ctx)
t.Logf("client21=%s", client21.id) t.Logf("client21=%s", client21.ID)
client22 := newTestClient(t, coord2, agent2.id) client22 := agpltest.NewClient(ctx, t, coord2, "client22", agent2.ID)
defer client22.close() defer client22.Close(ctx)
t.Logf("client22=%s", client22.id) t.Logf("client22=%s", client22.ID)
t.Logf("client11 -> Node 11") t.Logf("client11 -> Node 11")
client11.sendNode(&agpl.Node{PreferredDERP: 11}) client11.UpdateDERP(11)
assertEventuallyHasDERPs(ctx, t, agent1, 11) agent1.AssertEventuallyHasDERP(client11.ID, 11)
t.Logf("client21 -> Node 21") t.Logf("client21 -> Node 21")
client21.sendNode(&agpl.Node{PreferredDERP: 21}) client21.UpdateDERP(21)
assertEventuallyHasDERPs(ctx, t, agent1, 21) agent1.AssertEventuallyHasDERP(client21.ID, 21)
t.Logf("client22 -> Node 22") t.Logf("client22 -> Node 22")
client22.sendNode(&agpl.Node{PreferredDERP: 22}) client22.UpdateDERP(22)
assertEventuallyHasDERPs(ctx, t, agent2, 22) agent2.AssertEventuallyHasDERP(client22.ID, 22)
t.Logf("agent2 -> Node 2") t.Logf("agent2 -> Node 2")
agent2.sendNode(&agpl.Node{PreferredDERP: 2}) agent2.UpdateDERP(2)
assertEventuallyHasDERPs(ctx, t, client22, 2) client22.AssertEventuallyHasDERP(agent2.ID, 2)
assertEventuallyHasDERPs(ctx, t, client12, 2) client12.AssertEventuallyHasDERP(agent2.ID, 2)
t.Logf("client12 -> Node 12") t.Logf("client12 -> Node 12")
client12.sendNode(&agpl.Node{PreferredDERP: 12}) client12.UpdateDERP(12)
assertEventuallyHasDERPs(ctx, t, agent2, 12) agent2.AssertEventuallyHasDERP(client12.ID, 12)
t.Logf("agent1 -> Node 1") t.Logf("agent1 -> Node 1")
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
assertEventuallyHasDERPs(ctx, t, client21, 1) client21.AssertEventuallyHasDERP(agent1.ID, 1)
assertEventuallyHasDERPs(ctx, t, client11, 1) client11.AssertEventuallyHasDERP(agent1.ID, 1)
t.Logf("close coord2") t.Logf("close coord2")
err = coord2.Close() err = coord2.Close()
require.NoError(t, err) require.NoError(t, err)
// this closes agent2, client22, client21 // this closes agent2, client22, client21
err = agent2.recvErr(ctx, t) agent2.AssertEventuallyResponsesClosed()
require.ErrorIs(t, err, io.EOF) client22.AssertEventuallyResponsesClosed()
err = client22.recvErr(ctx, t) client21.AssertEventuallyResponsesClosed()
require.ErrorIs(t, err, io.EOF) assertEventuallyLost(ctx, t, store, agent2.ID)
err = client21.recvErr(ctx, t) assertEventuallyLost(ctx, t, store, client21.ID)
require.ErrorIs(t, err, io.EOF) assertEventuallyLost(ctx, t, store, client22.ID)
assertEventuallyLost(ctx, t, store, agent2.id)
assertEventuallyLost(ctx, t, store, client21.id)
assertEventuallyLost(ctx, t, store, client22.id)
err = coord1.Close() err = coord1.Close()
require.NoError(t, err) require.NoError(t, err)
// this closes agent1, client12, client11 // this closes agent1, client12, client11
err = agent1.recvErr(ctx, t) agent1.AssertEventuallyResponsesClosed()
require.ErrorIs(t, err, io.EOF) client12.AssertEventuallyResponsesClosed()
err = client12.recvErr(ctx, t) client11.AssertEventuallyResponsesClosed()
require.ErrorIs(t, err, io.EOF) assertEventuallyLost(ctx, t, store, agent1.ID)
err = client11.recvErr(ctx, t) assertEventuallyLost(ctx, t, store, client11.ID)
require.ErrorIs(t, err, io.EOF) assertEventuallyLost(ctx, t, store, client12.ID)
assertEventuallyLost(ctx, t, store, agent1.id)
assertEventuallyLost(ctx, t, store, client11.id)
assertEventuallyLost(ctx, t, store, client12.id)
// wait for all connections to close
err = agent1.close()
require.NoError(t, err)
agent1.waitForClose(ctx, t)
err = agent2.close()
require.NoError(t, err)
agent2.waitForClose(ctx, t)
err = client11.close()
require.NoError(t, err)
client11.waitForClose(ctx, t)
err = client12.close()
require.NoError(t, err)
client12.waitForClose(ctx, t)
err = client21.close()
require.NoError(t, err)
client21.waitForClose(ctx, t)
err = client22.close()
require.NoError(t, err)
client22.waitForClose(ctx, t)
} }
// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators. // TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
@ -623,53 +543,42 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coord3.Close() defer coord3.Close()
agent1 := newTestAgent(t, coord1, "agent1") agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
defer agent1.close() defer agent1.Close(ctx)
agent2 := newTestAgent(t, coord2, "agent2", agent1.id) agent2 := agpltest.NewPeer(ctx, t, coord2, "agent2",
defer agent2.close() agpltest.WithID(agent1.ID), agpltest.WithAuth(agpl.AgentCoordinateeAuth{ID: agent1.ID}),
)
defer agent2.Close(ctx)
client := newTestClient(t, coord3, agent1.id) client := agpltest.NewClient(ctx, t, coord3, "client", agent1.ID)
defer client.close() defer client.Close(ctx)
client.sendNode(&agpl.Node{PreferredDERP: 3}) client.UpdateDERP(3)
assertEventuallyHasDERPs(ctx, t, agent1, 3) agent1.AssertEventuallyHasDERP(client.ID, 3)
assertEventuallyHasDERPs(ctx, t, agent2, 3) agent2.AssertEventuallyHasDERP(client.ID, 3)
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) agent1.UpdateDERP(1)
assertEventuallyHasDERPs(ctx, t, client, 1) client.AssertEventuallyHasDERP(agent1.ID, 1)
// agent2's update overrides agent1 because it is newer // agent2's update overrides agent1 because it is newer
agent2.sendNode(&agpl.Node{PreferredDERP: 2}) agent2.UpdateDERP(2)
assertEventuallyHasDERPs(ctx, t, client, 2) client.AssertEventuallyHasDERP(agent1.ID, 2)
// agent2 disconnects, and we should revert back to agent1 // agent2 disconnects, and we should revert back to agent1
err = agent2.close() agent2.Close(ctx)
require.NoError(t, err) client.AssertEventuallyHasDERP(agent1.ID, 1)
err = agent2.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
agent2.waitForClose(ctx, t)
assertEventuallyHasDERPs(ctx, t, client, 1)
agent1.sendNode(&agpl.Node{PreferredDERP: 11}) agent1.UpdateDERP(11)
assertEventuallyHasDERPs(ctx, t, client, 11) client.AssertEventuallyHasDERP(agent1.ID, 11)
client.sendNode(&agpl.Node{PreferredDERP: 31}) client.UpdateDERP(31)
assertEventuallyHasDERPs(ctx, t, agent1, 31) agent1.AssertEventuallyHasDERP(client.ID, 31)
err = agent1.close() agent1.UngracefulDisconnect(ctx)
require.NoError(t, err) client.UngracefulDisconnect(ctx)
err = agent1.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
agent1.waitForClose(ctx, t)
err = client.close() assertEventuallyLost(ctx, t, store, client.ID)
require.NoError(t, err) assertEventuallyLost(ctx, t, store, agent1.ID)
err = client.recvErr(ctx, t)
require.ErrorIs(t, err, io.ErrClosedPipe)
client.waitForClose(ctx, t)
assertEventuallyLost(ctx, t, store, client.id)
assertEventuallyLost(ctx, t, store, agent1.id)
} }
func TestPGCoordinator_Unhealthy(t *testing.T) { func TestPGCoordinator_Unhealthy(t *testing.T) {
@ -683,7 +592,13 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
calls := make(chan struct{}) calls := make(chan struct{})
// first call succeeds, so that our Agent will successfully connect.
firstSucceeds := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
Times(1).
Return(database.TailnetCoordinator{}, nil)
// next 3 fail, so the Coordinator becomes unhealthy, and we test that it disconnects the agent
threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
After(firstSucceeds).
Times(3). Times(3).
Do(func(_ context.Context, _ uuid.UUID) { <-calls }). Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect")) Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
@ -710,23 +625,23 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
err := uut.Close() err := uut.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
agent1 := newTestAgent(t, uut, "agent1") agent1 := agpltest.NewAgent(ctx, t, uut, "agent1")
defer agent1.close() defer agent1.Close(ctx)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
select { select {
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatalf("timeout waiting for call %d", i+1)
case calls <- struct{}{}: case calls <- struct{}{}:
// OK // OK
} }
} }
// connected agent should be disconnected // connected agent should be disconnected
agent1.waitForClose(ctx, t) agent1.AssertEventuallyResponsesClosed()
// new agent should immediately disconnect // new agent should immediately disconnect
agent2 := newTestAgent(t, uut, "agent2") agent2 := agpltest.NewAgent(ctx, t, uut, "agent2")
defer agent2.close() defer agent2.Close(ctx)
agent2.waitForClose(ctx, t) agent2.AssertEventuallyResponsesClosed()
// next heartbeats succeed, so we are healthy // next heartbeats succeed, so we are healthy
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
@ -737,14 +652,9 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
// OK // OK
} }
} }
agent3 := newTestAgent(t, uut, "agent3") agent3 := agpltest.NewAgent(ctx, t, uut, "agent3")
defer agent3.close() defer agent3.Close(ctx)
select { agent3.AssertNotClosed(time.Second)
case <-agent3.closeChan:
t.Fatal("agent conn closed after we are healthy")
case <-time.After(time.Second):
// OK
}
} }
func TestPGCoordinator_Node_Empty(t *testing.T) { func TestPGCoordinator_Node_Empty(t *testing.T) {
@ -840,43 +750,39 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer coordinator.Close() defer coordinator.Close()
agent := newTestAgent(t, coordinator, "original") agent := agpltest.NewAgent(ctx, t, coordinator, "original")
defer agent.close() defer agent.Close(ctx)
agent.sendNode(&agpl.Node{PreferredDERP: 10}) agent.UpdateDERP(10)
client := newTestClient(t, coordinator, agent.id) client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
defer client.close() defer client.Close(ctx)
// Simulate some traffic to generate // Simulate some traffic to generate
// a peer. // a peer.
agentNodes := client.recvNodes(ctx, t) client.AssertEventuallyHasDERP(agent.ID, 10)
require.Len(t, agentNodes, 1) client.UpdateDERP(11)
assert.Equal(t, 10, agentNodes[0].PreferredDERP)
client.sendNode(&agpl.Node{PreferredDERP: 11})
clientNodes := agent.recvNodes(ctx, t) agent.AssertEventuallyHasDERP(client.ID, 11)
require.Len(t, clientNodes, 1)
assert.Equal(t, 11, clientNodes[0].PreferredDERP)
anode := coordinator.Node(agent.id) anode := coordinator.Node(agent.ID)
require.NotNil(t, anode) require.NotNil(t, anode)
cnode := coordinator.Node(client.id) cnode := coordinator.Node(client.ID)
require.NotNil(t, cnode) require.NotNil(t, cnode)
err = coordinator.Close() err = coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
assertEventuallyLost(ctx, t, store, agent.id) assertEventuallyLost(ctx, t, store, agent.ID)
assertEventuallyLost(ctx, t, store, client.id) assertEventuallyLost(ctx, t, store, client.ID)
coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store) coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
require.NoError(t, err) require.NoError(t, err)
defer coordinator2.Close() defer coordinator2.Close()
anode = coordinator2.Node(agent.id) anode = coordinator2.Node(agent.ID)
require.NotNil(t, anode) require.NotNil(t, anode)
assert.Equal(t, 10, anode.PreferredDERP) assert.Equal(t, 10, anode.PreferredDERP)
cnode = coordinator2.Node(client.id) cnode = coordinator2.Node(client.ID)
require.NotNil(t, cnode) require.NotNil(t, cnode)
assert.Equal(t, 11, cnode.PreferredDERP) assert.Equal(t, 11, cnode.PreferredDERP)
} }
@ -1007,144 +913,6 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
} }
type testConn struct {
ws, serverWS net.Conn
nodeChan chan []*agpl.Node
sendNode func(node *agpl.Node)
errChan <-chan error
id uuid.UUID
closeChan chan struct{}
}
func newTestConn(ids []uuid.UUID) *testConn {
a := &testConn{}
a.ws, a.serverWS = net.Pipe()
a.nodeChan = make(chan []*agpl.Node)
a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error {
a.nodeChan <- nodes
return nil
})
if len(ids) > 1 {
panic("too many")
}
if len(ids) == 1 {
a.id = ids[0]
} else {
a.id = uuid.New()
}
a.closeChan = make(chan struct{})
return a
}
func newTestAgent(t *testing.T, coord agpl.CoordinatorV1, name string, id ...uuid.UUID) *testConn {
a := newTestConn(id)
go func() {
err := coord.ServeAgent(a.serverWS, a.id, name)
assert.NoError(t, err)
close(a.closeChan)
}()
return a
}
func newTestClient(t *testing.T, coord agpl.CoordinatorV1, agentID uuid.UUID, id ...uuid.UUID) *testConn {
c := newTestConn(id)
go func() {
err := coord.ServeClient(c.serverWS, c.id, agentID)
assert.NoError(t, err)
close(c.closeChan)
}()
return c
}
func (c *testConn) close() error {
return c.ws.Close()
}
func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
t.Helper()
select {
case <-ctx.Done():
t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
return nil
case nodes := <-c.nodeChan:
return nodes
}
}
func (c *testConn) recvErr(ctx context.Context, t *testing.T) error {
t.Helper()
// pgCoord works on eventual consistency, so it sometimes sends extra node
// updates, and these block errors if not read from the nodes channel.
for {
select {
case nodes := <-c.nodeChan:
t.Logf("ignoring nodes update while waiting for error; id=%s, nodes=%+v",
c.id.String(), nodes)
continue
case <-ctx.Done():
t.Fatal("timeout receiving error")
return ctx.Err()
case err := <-c.errChan:
return err
}
}
}
func (c *testConn) waitForClose(ctx context.Context, t *testing.T) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timeout waiting for connection to close")
return
case <-c.closeChan:
return
}
}
func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
t.Helper()
for {
nodes := c.recvNodes(ctx, t)
if len(nodes) != len(expected) {
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
continue
}
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if !slices.Contains(derps, e) {
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
t.Helper()
for {
select {
case <-ctx.Done():
return
case nodes := <-c.nodeChan:
derps := make([]int, 0, len(nodes))
for _, n := range nodes {
derps = append(derps, n.PreferredDERP)
}
for _, e := range expected {
if slices.Contains(derps, e) {
t.Fatalf("expected not to get DERP %d, but received it", e)
return
}
}
}
}
}
func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) { func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
t.Helper() t.Helper()
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {

View File

@ -1,19 +1,13 @@
package tailnet package tailnet
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors"
"net" "net"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
agpl "github.com/coder/coder/v2/tailnet" agpl "github.com/coder/coder/v2/tailnet"
) )
@ -38,10 +32,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
return err return err
} }
switch major { switch major {
case 1:
coord := *(s.CoordPtr.Load())
sub := coord.ServeMultiAgent(id)
return ServeWorkspaceProxy(ctx, conn, sub)
case 2: case 2:
auth := agpl.SingleTailnetCoordinateeAuth{} auth := agpl.SingleTailnetCoordinateeAuth{}
streamID := agpl.StreamID{ streamID := agpl.StreamID{
@ -52,103 +42,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
return s.ServeConnV2(ctx, conn, streamID) return s.ServeConnV2(ctx, conn, streamID)
default: default:
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
return xerrors.New("unsupported version") return agpl.ErrUnsupportedVersion
}
}
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
go func() {
//nolint:staticcheck
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)
//nolint:staticcheck
if err != nil {
_ = conn.Close()
}
}()
decoder := json.NewDecoder(conn)
for {
var msg wsproxysdk.CoordinateMessage
err := decoder.Decode(&msg)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return xerrors.Errorf("read json: %w", err)
}
switch msg.Type {
case wsproxysdk.CoordinateMessageTypeSubscribe:
err := ma.SubscribeAgent(msg.AgentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
case wsproxysdk.CoordinateMessageTypeUnsubscribe:
err := ma.UnsubscribeAgent(msg.AgentID)
if err != nil {
return xerrors.Errorf("unsubscribe agent: %w", err)
}
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
pn, err := agpl.NodeToProto(msg.Node)
if err != nil {
return err
}
err = ma.UpdateSelf(pn)
if err != nil {
return xerrors.Errorf("update self: %w", err)
}
default:
return xerrors.Errorf("unknown message type %q", msg.Type)
}
}
}
// Linter fails because this function always returns an error. This function blocks
// until it errors, so this is ok.
//
//nolint:staticcheck
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
var lastData []byte
for {
resp, ok := ma.NextUpdate(ctx)
if !ok {
return xerrors.New("multiagent is closed")
}
nodes, err := agpl.OnlyNodeUpdates(resp)
if err != nil {
return xerrors.Errorf("failed to convert response: %w", err)
}
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
if err != nil {
return err
}
if bytes.Equal(lastData, data) {
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
return err
}
_, err = conn.Write(data)
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
return err
}
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = conn.SetWriteDeadline(time.Time{})
if err != nil {
return err
}
lastData = data
} }
} }

View File

@ -11,6 +11,7 @@ import {
} from "../helpers"; } from "../helpers";
import { beforeCoderTest } from "../hooks"; import { beforeCoderTest } from "../hooks";
// we no longer support versions prior to Tailnet v2 API support: https://github.com/coder/coder/commit/059e533544a0268acbc8831006b2858ead2f0d8e
const clientVersion = "v2.8.0"; const clientVersion = "v2.8.0";
test.beforeEach(({ page }) => beforeCoderTest(page)); test.beforeEach(({ page }) => beforeCoderTest(page));

View File

@ -2,11 +2,9 @@ package tailnet
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"io" "io"
"net"
"net/http" "net/http"
"net/netip" "net/netip"
"sync" "sync"
@ -14,7 +12,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"nhooyr.io/websocket"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -22,35 +19,23 @@ import (
"github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/proto"
) )
const (
// ResponseBufferSize is the max number of responses to buffer per connection before we start
// dropping updates
ResponseBufferSize = 512
// RequestBufferSize is the max number of requests to buffer per connection
RequestBufferSize = 32
)
// Coordinator exchanges nodes with agents to establish connections. // Coordinator exchanges nodes with agents to establish connections.
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
// Coordinators have different guarantees for HA support. // Coordinators have different guarantees for HA support.
type Coordinator interface { type Coordinator interface {
CoordinatorV1
CoordinatorV2 CoordinatorV2
} }
type CoordinatorV1 interface {
// ServeHTTPDebug serves a debug webpage that shows the internal state of
// the coordinator.
ServeHTTPDebug(w http.ResponseWriter, r *http.Request)
// Node returns an in-memory node by ID.
Node(id uuid.UUID) *Node
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
// ServeAgent accepts a WebSocket connection to an agent that listens to
// incoming connections and publishes node updates.
// Name is just used for debug information. It can be left blank.
ServeAgent(conn net.Conn, id uuid.UUID, name string) error
// Close closes the coordinator.
Close() error
ServeMultiAgent(id uuid.UUID) MultiAgentConn
}
// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API. // CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API.
type CoordinatorV2 interface { type CoordinatorV2 interface {
// ServeHTTPDebug serves a debug webpage that shows the internal state of // ServeHTTPDebug serves a debug webpage that shows the internal state of
@ -60,6 +45,7 @@ type CoordinatorV2 interface {
Node(id uuid.UUID) *Node Node(id uuid.UUID) *Node
Close() error Close() error
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
ServeMultiAgent(id uuid.UUID) MultiAgentConn
} }
// Node represents a node in the network. // Node represents a node in the network.
@ -389,44 +375,6 @@ func (c *inMemoryCoordination) Close() error {
} }
} }
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
errChan := make(chan error, 1)
sendErr := func(err error) {
select {
case errChan <- err:
default:
}
}
go func() {
decoder := json.NewDecoder(conn)
for {
var nodes []*Node
err := decoder.Decode(&nodes)
if err != nil {
sendErr(xerrors.Errorf("read: %w", err))
return
}
err = updateNodes(nodes)
if err != nil {
sendErr(xerrors.Errorf("update nodes: %w", err))
}
}
}()
return func(node *Node) {
data, err := json.Marshal(node)
if err != nil {
sendErr(xerrors.Errorf("marshal node: %w", err))
return
}
_, err = conn.Write(data)
if err != nil {
sendErr(xerrors.Errorf("write: %w", err))
}
}, errChan
}
const LoggerName = "coord" const LoggerName = "coord"
var ( var (
@ -540,11 +488,11 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge
}, },
}).Init() }).Init()
go v1RespLoop(ctx, cancel, logger, m, resps) go qRespLoop(ctx, cancel, logger, m, resps)
return m return m
} }
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; // core is an in-memory structure of peer mappings. Its methods may be called from multiple goroutines;
// it is protected by a mutex to ensure data stay consistent. // it is protected by a mutex to ensure data stay consistent.
type core struct { type core struct {
logger slog.Logger logger slog.Logger
@ -607,42 +555,6 @@ func (c *core) node(id uuid.UUID) *Node {
return v1Node return v1Node
} }
// ServeClient accepts a WebSocket connection that wants to connect to an agent
// with the specified ID.
func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ServeClientV1(ctx, c.core.logger, c, conn, id, agentID)
}
// ServeClientV1 adapts a v1 Client to a v2 Coordinator
func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
logger = logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
defer func() {
err := conn.Close()
if err != nil {
logger.Debug(ctx, "closing client connection", slog.Error(err))
}
}()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent})
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)},
})
if err != nil {
// can only be a context error, no need to log here.
return err
}
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient)
go tc.SendUpdates()
go v1RespLoop(ctx, cancel, logger, tc, resps)
go v1ReqLoop(ctx, cancel, logger, conn, reqs)
<-ctx.Done()
return nil
}
func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error { func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
@ -887,34 +799,6 @@ func (c *core) removePeerLocked(id uuid.UUID, kind proto.CoordinateResponse_Peer
delete(c.peers, id) delete(c.peers, id)
} }
// ServeAgent accepts a WebSocket connection to an agent that
// listens to incoming connections and publishes node updates.
func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ServeAgentV1(ctx, c.core.logger, c, conn, id, name)
}
func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, name string) error {
logger = logger.With(slog.F("agent_id", id), slog.F("name", name))
defer func() {
logger.Debug(ctx, "closing agent connection")
err := conn.Close()
logger.Debug(ctx, "closed agent connection", slog.Error(err))
}()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger.Debug(ctx, "starting new agent connection")
reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id})
tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
go tc.SendUpdates()
go v1RespLoop(ctx, cancel, logger, tc, resps)
go v1ReqLoop(ctx, cancel, logger, conn, reqs)
<-ctx.Done()
logger.Debug(ctx, "ending agent connection")
return nil
}
// Close closes all of the open connections in the coordinator and stops the // Close closes all of the open connections in the coordinator and stops the
// coordinator from accepting new connections. // coordinator from accepting new connections.
func (c *coordinator) Close() error { func (c *coordinator) Close() error {
@ -1073,44 +957,7 @@ func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
} }
} }
func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, func qRespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
conn net.Conn, reqs chan<- *proto.CoordinateRequest,
) {
defer close(reqs)
defer cancel()
decoder := json.NewDecoder(conn)
for {
var node Node
err := decoder.Decode(&node)
if err != nil {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, io.ErrClosedPipe) ||
xerrors.Is(err, context.Canceled) ||
xerrors.Is(err, context.DeadlineExceeded) ||
websocket.CloseStatus(err) > 0 {
logger.Debug(ctx, "v1ReqLoop exiting", slog.Error(err))
} else {
logger.Info(ctx, "v1ReqLoop failed to decode Node update", slog.Error(err))
}
return
}
logger.Debug(ctx, "v1ReqLoop got node update", slog.F("node", node))
pn, err := NodeToProto(&node)
if err != nil {
logger.Critical(ctx, "v1ReqLoop failed to convert v1 node", slog.F("node", node), slog.Error(err))
return
}
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
Node: pn,
}}
if err := SendCtx(ctx, reqs, req); err != nil {
logger.Debug(ctx, "v1ReqLoop ctx expired", slog.Error(err))
return
}
}
}
func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
defer func() { defer func() {
cErr := q.Close() cErr := q.Close()
if cErr != nil { if cErr != nil {
@ -1121,13 +968,13 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
for { for {
resp, err := RecvCtx(ctx, resps) resp, err := RecvCtx(ctx, resps)
if err != nil { if err != nil {
logger.Debug(ctx, "v1RespLoop done reading responses", slog.Error(err)) logger.Debug(ctx, "qRespLoop done reading responses", slog.Error(err))
return return
} }
logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp)) logger.Debug(ctx, "qRespLoop got response", slog.F("resp", resp))
err = q.Enqueue(resp) err = q.Enqueue(resp)
if err != nil && !xerrors.Is(err, context.Canceled) { if err != nil && !xerrors.Is(err, context.Canceled) {
logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err)) logger.Error(ctx, "qRespLoop failed to enqueue v1 update", slog.Error(err))
} }
} }
} }
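
With the v1 JSON loops above removed, every coordinatee talks to the coordinator over the v2 request/response channels directly. Below is a minimal sketch of that flow using only exported pieces that survive this change (the function itself is hypothetical; the real wiring lives in the RPC client service, not here):

package example // hypothetical sketch, not part of this change

import (
	"context"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// coordinateV2 shows the request/response flow that replaces the deleted v1
// loops: add a tunnel, publish a self update, then drain responses until the
// context ends. Error handling is kept minimal on purpose.
func coordinateV2(ctx context.Context, coord tailnet.CoordinatorV2, id, agent uuid.UUID) error {
	reqs, resps := coord.Coordinate(ctx, id, id.String(), tailnet.ClientCoordinateeAuth{AgentID: agent})
	if err := tailnet.SendCtx(ctx, reqs, &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(agent)},
	}); err != nil {
		return err // can only be a context error
	}
	if err := tailnet.SendCtx(ctx, reqs, &proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 10}},
	}); err != nil {
		return err
	}
	for {
		resp, err := tailnet.RecvCtx(ctx, resps)
		if err != nil {
			return err // context canceled or responses closed
		}
		_ = resp.PeerUpdates // hand these to whatever configures the local tailnet
	}
}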

View File

@ -2,10 +2,7 @@ package tailnet_test
import ( import (
"context" "context"
"encoding/json"
"net" "net"
"net/http"
"net/http/httptest"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -13,10 +10,8 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"nhooyr.io/websocket"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -34,162 +29,107 @@ func TestCoordinator(t *testing.T) {
t.Run("ClientWithoutAgent", func(t *testing.T) { t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitShort)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { defer func() {
err := coordinator.Close() err := coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
return nil defer client.Close(ctx)
}) client.UpdateNode(&proto.Node{
id := uuid.New() Addresses: []string{netip.PrefixFrom(tailnet.IP(), 128).String()},
closeChan := make(chan struct{}) PreferredDerp: 10,
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
},
PreferredDERP: 10,
}) })
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(id) != nil return coordinator.Node(client.ID) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, client.Close())
require.NoError(t, server.Close())
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) { t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitShort)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { defer func() {
err := coordinator.Close() err := coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
return nil
})
id := uuid.New()
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(server, id, uuid.New())
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 64),
},
PreferredDERP: 10,
})
_ = testutil.RequireRecvCtx(ctx, t, errChan) client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
_ = testutil.RequireRecvCtx(ctx, t, closeChan) defer client.Close(ctx)
client.UpdateNode(&proto.Node{
Addresses: []string{
netip.PrefixFrom(tailnet.IP(), 64).String(),
},
PreferredDerp: 10,
})
client.AssertEventuallyResponsesClosed()
}) })
t.Run("AgentWithoutClients", func(t *testing.T) { t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitShort)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { defer func() {
err := coordinator.Close() err := coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
client, server := net.Pipe()
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { agent := test.NewAgent(ctx, t, coordinator, "agent")
return nil defer agent.Close(ctx)
}) agent.UpdateNode(&proto.Node{
id := uuid.New() Addresses: []string{
closeChan := make(chan struct{}) netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 128).String(),
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 128),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(id) != nil return coordinator.Node(agent.ID) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
err := client.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, errChan)
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) { t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitShort)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { defer func() {
err := coordinator.Close() err := coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
client, server := net.Pipe() agent := test.NewAgent(ctx, t, coordinator, "agent")
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { defer agent.Close(ctx)
return nil agent.UpdateNode(&proto.Node{
}) Addresses: []string{
id := uuid.New() netip.PrefixFrom(tailnet.IP(), 128).String(),
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IP(), 128),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
_ = testutil.RequireRecvCtx(ctx, t, errChan) agent.AssertEventuallyResponsesClosed()
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) { t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitShort)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
defer func() { defer func() {
err := coordinator.Close() err := coordinator.Close()
require.NoError(t, err) require.NoError(t, err)
}() }()
client, server := net.Pipe() agent := test.NewAgent(ctx, t, coordinator, "agent")
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { defer agent.Close(ctx)
return nil agent.UpdateNode(&proto.Node{
}) Addresses: []string{
id := uuid.New() netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 64).String(),
closeChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(server, id, "")
assert.NoError(t, err)
close(closeChan)
}()
sendNode(&tailnet.Node{
Addresses: []netip.Prefix{
netip.PrefixFrom(tailnet.IPFromUUID(id), 64),
}, },
PreferredDERP: 10, PreferredDerp: 10,
}) })
_ = testutil.RequireRecvCtx(ctx, t, errChan) agent.AssertEventuallyResponsesClosed()
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
}) })
t.Run("AgentWithClient", func(t *testing.T) { t.Run("AgentWithClient", func(t *testing.T) {
@ -201,180 +141,71 @@ func TestCoordinator(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
// in this test we use real websockets to test use of deadlines ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancel() defer cancel()
agentWS, agentServerWS := websocketConn(ctx, t) agent := test.NewAgent(ctx, t, coordinator, "agent")
defer agentWS.Close() defer agent.Close(ctx)
agentNodeChan := make(chan []*tailnet.Node) agent.UpdateDERP(1)
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
agentNodeChan <- nodes
return nil
})
agentID := uuid.New()
closeAgentChan := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil return coordinator.Node(agent.ID) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := websocketConn(ctx, t) client := test.NewClient(ctx, t, coordinator, "client", agent.ID)
defer clientWS.Close() defer client.Close(ctx)
defer clientServerWS.Close() client.AssertEventuallyHasDERP(agent.ID, 1)
clientNodeChan := make(chan []*tailnet.Node)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
clientNodeChan <- nodes
return nil
})
clientID := uuid.New()
closeClientChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{PreferredDERP: 2}) client.UpdateDERP(2)
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan) agent.AssertEventuallyHasDERP(client.ID, 2)
require.Len(t, clientNodes, 1)
// wait longer than the internal wait timeout.
// this tests for regression of https://github.com/coder/coder/issues/7428
time.Sleep(tailnet.WriteTimeout * 3 / 2)
// Ensure an update to the agent node reaches the client! // Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{PreferredDERP: 3}) agent.UpdateDERP(3)
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan) client.AssertEventuallyHasDERP(agent.ID, 3)
require.Len(t, agentNodes, 1)
// Close the agent WebSocket so a new one can connect. // Close the agent so a new one can connect.
err := agentWS.Close() agent.Close(ctx)
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
// Create a new agent connection. This is to simulate a reconnect! // Create a new agent connection. This is to simulate a reconnect!
agentWS, agentServerWS = net.Pipe() agent = test.NewPeer(ctx, t, coordinator, "agent", test.WithID(agent.ID))
defer agentWS.Close() defer agent.Close(ctx)
agentNodeChan = make(chan []*tailnet.Node) // Ensure the agent gets the existing client node immediately!
_, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error { agent.AssertEventuallyHasDERP(client.ID, 2)
agentNodeChan <- nodes
return nil
})
closeAgentChan = make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
close(closeAgentChan)
}()
// Ensure the existing listening client sends its node immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
require.Len(t, clientNodes, 1)
err = agentWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
}) })
t.Run("AgentDoubleConnect", func(t *testing.T) { t.Run("AgentDoubleConnect", func(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger) coordinator := tailnet.NewCoordinator(logger)
ctx := testutil.Context(t, testutil.WaitLong) ctx := testutil.Context(t, testutil.WaitShort)
agentWS1, agentServerWS1 := net.Pipe()
defer agentWS1.Close()
agentNodeChan1 := make(chan []*tailnet.Node)
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
t.Logf("agent1 got node update: %v", nodes)
agentNodeChan1 <- nodes
return nil
})
agentID := uuid.New() agentID := uuid.New()
closeAgentChan1 := make(chan struct{}) agent1 := test.NewPeer(ctx, t, coordinator, "agent1", test.WithID(agentID))
go func() { defer agent1.Close(ctx)
err := coordinator.ServeAgent(agentServerWS1, agentID, "") agent1.UpdateDERP(1)
assert.NoError(t, err)
close(closeAgentChan1)
}()
sendAgentNode1(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast) }, testutil.WaitShort, testutil.IntervalFast)
clientWS, clientServerWS := net.Pipe() client := test.NewPeer(ctx, t, coordinator, "client")
defer clientWS.Close() defer client.Close(ctx)
defer clientServerWS.Close() client.AddTunnel(agentID)
clientNodeChan := make(chan []*tailnet.Node) client.AssertEventuallyHasDERP(agent1.ID, 1)
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
t.Logf("client got node update: %v", nodes) client.UpdateDERP(2)
clientNodeChan <- nodes agent1.AssertEventuallyHasDERP(client.ID, 2)
return nil
})
clientID := uuid.New()
closeClientChan := make(chan struct{})
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
close(closeClientChan)
}()
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client! // Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{PreferredDERP: 3}) agent1.UpdateDERP(3)
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan) client.AssertEventuallyHasDERP(agent1.ID, 3)
require.Len(t, agentNodes, 1)
// Create a new agent connection without disconnecting the old one. // Create a new agent connection without disconnecting the old one.
agentWS2, agentServerWS2 := net.Pipe() agent2 := test.NewPeer(ctx, t, coordinator, "agent2", test.WithID(agentID))
defer agentWS2.Close() defer agent2.Close(ctx)
agentNodeChan2 := make(chan []*tailnet.Node)
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
t.Logf("agent2 got node update: %v", nodes)
agentNodeChan2 <- nodes
return nil
})
closeAgentChan2 := make(chan struct{})
go func() {
err := coordinator.ServeAgent(agentServerWS2, agentID, "")
assert.NoError(t, err)
close(closeAgentChan2)
}()
// Ensure the existing listening client sends it's node immediately! // Ensure the existing client node gets sent immediately!
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2) agent2.AssertEventuallyHasDERP(client.ID, 2)
require.Len(t, clientNodes, 1)
// This original agent websocket should've been closed forcefully. // This original agent's channels should've been closed forcefully.
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1) agent1.AssertEventuallyResponsesClosed()
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
err := agentWS2.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
err = clientWS.Close()
require.NoError(t, err)
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
}) })
t.Run("AgentAck", func(t *testing.T) { t.Run("AgentAck", func(t *testing.T) {
@ -396,89 +227,6 @@ func TestCoordinator(t *testing.T) {
}) })
} }
// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
// https://github.com/coder/coder/issues/7295
func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
coordinator := tailnet.NewCoordinator(logger)
agentWS, agentServerWS := net.Pipe()
defer agentWS.Close()
agentID := uuid.New()
go func() {
err := coordinator.ServeAgent(agentServerWS, agentID, "")
assert.NoError(t, err)
}()
// send an agent update before the client connects so that there is
// node data available to send right away.
aNode := tailnet.Node{PreferredDERP: 0}
aData, err := json.Marshal(&aNode)
require.NoError(t, err)
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
_, err = agentWS.Write(aData)
require.NoError(t, err)
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Connect from the client
clientWS, clientServerWS := net.Pipe()
defer clientWS.Close()
clientID := uuid.New()
go func() {
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
assert.NoError(t, err)
}()
// peek one byte from the node update, so we know the coordinator is
// trying to write to the client.
// buffer needs to be 2 characters longer because return value is a list
// so, it needs [ and ]
buf := make([]byte, len(aData)+2)
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err := clientWS.Read(buf[:1])
require.NoError(t, err)
require.Equal(t, 1, n)
// send a second update
aNode.PreferredDERP = 1
require.NoError(t, err)
aData, err = json.Marshal(&aNode)
require.NoError(t, err)
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
_, err = agentWS.Write(aData)
require.NoError(t, err)
// read the rest of the update from the client, should be initial node.
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err = clientWS.Read(buf[1:])
require.NoError(t, err)
var cNodes []*tailnet.Node
err = json.Unmarshal(buf[:n+1], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 0, cNodes[0].PreferredDERP)
// read second update
// without a fix for https://github.com/coder/coder/issues/7295 our
// read would time out here.
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
require.NoError(t, err)
n, err = clientWS.Read(buf)
require.NoError(t, err)
err = json.Unmarshal(buf[:n], &cNodes)
require.NoError(t, err)
require.Len(t, cNodes, 1)
require.Equal(t, 1, cNodes[0].PreferredDERP)
}
func TestCoordinator_BidirectionalTunnels(t *testing.T) { func TestCoordinator_BidirectionalTunnels(t *testing.T) {
t.Parallel() t.Parallel()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
@ -521,29 +269,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
ma1.RequireEventuallyClosed(ctx) ma1.RequireEventuallyClosed(ctx)
} }
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
t.Helper()
sc := make(chan net.Conn, 1)
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
wss, err := websocket.Accept(rw, r, nil)
require.NoError(t, err)
conn := websocket.NetConn(r.Context(), wss, websocket.MessageBinary)
sc <- conn
close(sc) // there can be only one
// hold open until context canceled
<-ctx.Done()
}))
t.Cleanup(s.Close)
// nolint: bodyclose
wsc, _, err := websocket.Dial(ctx, s.URL, nil)
require.NoError(t, err)
client = websocket.NetConn(ctx, wsc, websocket.MessageBinary)
server, ok := <-sc
require.True(t, ok)
return client, server
}
func TestInMemoryCoordination(t *testing.T) { func TestInMemoryCoordination(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)

View File

@ -9,4 +9,4 @@ const (
CurrentMinor = 2 CurrentMinor = 2
) )
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1) var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)

View File

@ -21,6 +21,8 @@ import (
"github.com/coder/quartz" "github.com/coder/quartz"
) )
var ErrUnsupportedVersion = xerrors.New("unsupported version")
type streamIDContextKey struct{} type streamIDContextKey struct{}
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this // StreamID identifies the caller of the CoordinateTailnet RPC. We store this
@ -47,7 +49,7 @@ type ClientServiceOptions struct {
} }
// ClientService is a tailnet coordination service that accepts a connection and version from a // ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and supports versions 1.0 and 2.x of the Tailnet API protocol. // tailnet client, and supports version 2.x of the Tailnet API protocol.
type ClientService struct { type ClientService struct {
Logger slog.Logger Logger slog.Logger
CoordPtr *atomic.Pointer[Coordinator] CoordPtr *atomic.Pointer[Coordinator]
@ -94,9 +96,6 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
return err return err
} }
switch major { switch major {
case 1:
coord := *(s.CoordPtr.Load())
return coord.ServeClient(conn, id, agent)
case 2: case 2:
auth := ClientCoordinateeAuth{AgentID: agent} auth := ClientCoordinateeAuth{AgentID: agent}
streamID := StreamID{ streamID := StreamID{
@ -107,7 +106,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
return s.ServeConnV2(ctx, conn, streamID) return s.ServeConnV2(ctx, conn, streamID)
default: default:
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
return xerrors.New("unsupported version") return ErrUnsupportedVersion
} }
} }
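
Because the unsupported-version path now returns a sentinel instead of a fresh error string, callers can match on it with xerrors.Is. A small hedged sketch follows; the handler function is made up for illustration, and only ErrUnsupportedVersion, slog, and xerrors come from existing code:

package example // hypothetical sketch, not part of this change

import (
	"context"

	"golang.org/x/xerrors"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/tailnet"
)

// handleServeErr reacts to the error returned by ClientService.ServeClient
// (arguments to that call elided on purpose).
func handleServeErr(ctx context.Context, logger slog.Logger, serveErr error) {
	if xerrors.Is(serveErr, tailnet.ErrUnsupportedVersion) {
		// The peer negotiated a Tailnet API major version we no longer speak;
		// it needs a client new enough for v2 rather than a retry.
		logger.Warn(ctx, "peer negotiated an unsupported tailnet API version", slog.Error(serveErr))
		return
	}
	if serveErr != nil {
		logger.Debug(ctx, "coordination ended", slog.Error(serveErr))
	}
}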

View File

@ -162,21 +162,8 @@ func TestClientService_ServeClient_V1(t *testing.T) {
errCh <- err errCh <- err
}() }()
call := testutil.RequireRecvCtx(ctx, t, fCoord.ServeClientCalls)
require.NotNil(t, call)
require.Equal(t, call.ID, clientID)
require.Equal(t, call.Agent, agentID)
require.Equal(t, s, call.Conn)
expectedError := xerrors.New("test error")
select {
case call.ErrCh <- expectedError:
// ok!
case <-ctx.Done():
t.Fatalf("timeout sending error")
}
err = testutil.RequireRecvCtx(ctx, t, errCh) err = testutil.RequireRecvCtx(ctx, t, errCh)
require.ErrorIs(t, err, expectedError) require.ErrorIs(t, err, tailnet.ErrUnsupportedVersion)
} }
func TestNetworkTelemetryBatcher(t *testing.T) { func TestNetworkTelemetryBatcher(t *testing.T) {

View File

@ -11,7 +11,6 @@ package tailnettest
import ( import (
context "context" context "context"
net "net"
http "net/http" http "net/http"
reflect "reflect" reflect "reflect"
@ -87,34 +86,6 @@ func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
} }
// ServeAgent mocks base method.
func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeAgent indicates an expected call of ServeAgent.
func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2)
}
// ServeClient mocks base method.
func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// ServeClient indicates an expected call of ServeClient.
func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2)
}
// ServeHTTPDebug mocks base method. // ServeHTTPDebug mocks base method.
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) { func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -161,7 +161,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
type TestMultiAgent struct { type TestMultiAgent struct {
t testing.TB t testing.TB
id uuid.UUID ID uuid.UUID
a tailnet.MultiAgentConn a tailnet.MultiAgentConn
nodeKey []byte nodeKey []byte
discoKey string discoKey string
@ -172,8 +172,8 @@ func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent
require.NoError(t, err) require.NoError(t, err)
dk, err := key.NewDisco().Public().MarshalText() dk, err := key.NewDisco().Public().MarshalText()
require.NoError(t, err) require.NoError(t, err)
m := &TestMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} m := &TestMultiAgent{t: t, ID: uuid.New(), nodeKey: nk, discoKey: string(dk)}
m.a = coord.ServeMultiAgent(m.id) m.a = coord.ServeMultiAgent(m.ID)
return m return m
} }
@ -277,8 +277,7 @@ func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
} }
type FakeCoordinator struct { type FakeCoordinator struct {
CoordinateCalls chan *FakeCoordinate CoordinateCalls chan *FakeCoordinate
ServeClientCalls chan *FakeServeClient
} }
func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) { func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) {
@ -289,21 +288,6 @@ func (*FakeCoordinator) Node(uuid.UUID) *tailnet.Node {
panic("unimplemented") panic("unimplemented")
} }
func (f *FakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
errCh := make(chan error)
f.ServeClientCalls <- &FakeServeClient{
Conn: conn,
ID: id,
Agent: agent,
ErrCh: errCh,
}
return <-errCh
}
func (*FakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error {
panic("unimplemented")
}
func (*FakeCoordinator) Close() error { func (*FakeCoordinator) Close() error {
panic("unimplemented") panic("unimplemented")
} }
@ -328,8 +312,7 @@ func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name str
func NewFakeCoordinator() *FakeCoordinator { func NewFakeCoordinator() *FakeCoordinator {
return &FakeCoordinator{ return &FakeCoordinator{
CoordinateCalls: make(chan *FakeCoordinate, 100), CoordinateCalls: make(chan *FakeCoordinate, 100),
ServeClientCalls: make(chan *FakeServeClient, 100),
} }
} }

View File

@ -3,10 +3,12 @@ package test
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"tailscale.com/types/key"
"github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/proto"
@ -18,43 +20,73 @@ type PeerStatus struct {
readyForHandshake bool readyForHandshake bool
} }
type PeerOption func(*Peer)
func WithID(id uuid.UUID) PeerOption {
return func(p *Peer) {
p.ID = id
}
}
func WithAuth(auth tailnet.CoordinateeAuth) PeerOption {
return func(p *Peer) {
p.auth = auth
}
}
type Peer struct { type Peer struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
t testing.TB t testing.TB
ID uuid.UUID ID uuid.UUID
auth tailnet.CoordinateeAuth
name string name string
nodeKey key.NodePublic
discoKey key.DiscoPublic
resps <-chan *proto.CoordinateResponse resps <-chan *proto.CoordinateResponse
reqs chan<- *proto.CoordinateRequest reqs chan<- *proto.CoordinateRequest
peers map[uuid.UUID]PeerStatus peers map[uuid.UUID]PeerStatus
peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
} }
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer { func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, opts ...PeerOption) *Peer {
p := &Peer{ p := &Peer{
t: t, t: t,
name: name, name: name,
peers: make(map[uuid.UUID]PeerStatus), peers: make(map[uuid.UUID]PeerStatus),
peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate), peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
ID: uuid.New(),
// SingleTailnetCoordinateeAuth allows connections to arbitrary peers
auth: tailnet.SingleTailnetCoordinateeAuth{},
// required for converting to and from protobuf, so we always include them
nodeKey: key.NewNode().Public(),
discoKey: key.NewDisco().Public(),
} }
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
if len(id) > 1 { for _, opt := range opts {
t.Fatal("too many") opt(p)
} }
if len(id) == 1 {
p.ID = id[0] p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, p.auth)
} else { return p
p.ID = uuid.New() }
}
// SingleTailnetTunnelAuth allows connections to arbitrary peers // NewAgent is a wrapper around NewPeer, creating a peer with Agent auth tied to its ID
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{}) func NewAgent(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string) *Peer {
id := uuid.New()
return NewPeer(ctx, t, coord, name, WithID(id), WithAuth(tailnet.AgentCoordinateeAuth{ID: id}))
}
// NewClient is a wrapper around NewPeer, creating a peer with Client auth tied to the provided agentID
func NewClient(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, agentID uuid.UUID) *Peer {
p := NewPeer(ctx, t, coord, name, WithAuth(tailnet.ClientCoordinateeAuth{AgentID: agentID}))
p.AddTunnel(agentID)
return p return p
} }
func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) { func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
p.t.Helper() p.t.Helper()
p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, p.auth)
p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, tailnet.SingleTailnetCoordinateeAuth{})
} }
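
Taken together, the new constructors make coordinator tests read declaratively. A hypothetical end-to-end sketch using only helpers defined in this file plus NewCoordinator (the test name and package are illustrative, not part of the change):

package example_test // hypothetical; would live in some _test.go file

import (
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

func TestPeerHelpers_Sketch(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitShort)
	coord := tailnet.NewCoordinator(logger)
	defer coord.Close()

	// Agent auth is tied to the agent's own ID by NewAgent.
	agent := test.NewAgent(ctx, t, coord, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(1)

	// NewClient tunnels to the agent as part of construction.
	client := test.NewClient(ctx, t, coord, "client", agent.ID)
	defer client.Close(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 1)

	// WithID lets a test simulate the same agent reconnecting.
	agent.Close(ctx)
	agent2 := test.NewPeer(ctx, t, coord, "agent2", test.WithID(agent.ID))
	defer agent2.Close(ctx)
	agent2.UpdateDERP(2)
	client.AssertEventuallyHasDERP(agent.ID, 2)
}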
func (p *Peer) AddTunnel(other uuid.UUID) { func (p *Peer) AddTunnel(other uuid.UUID) {
@ -71,7 +103,19 @@ func (p *Peer) AddTunnel(other uuid.UUID) {
func (p *Peer) UpdateDERP(derp int32) { func (p *Peer) UpdateDERP(derp int32) {
p.t.Helper() p.t.Helper()
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}} node := &proto.Node{PreferredDerp: derp}
p.UpdateNode(node)
}
func (p *Peer) UpdateNode(node *proto.Node) {
p.t.Helper()
nk, err := p.nodeKey.MarshalBinary()
assert.NoError(p.t, err)
node.Key = nk
dk, err := p.discoKey.MarshalText()
assert.NoError(p.t, err)
node.Disco = string(dk)
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}}
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
p.t.Errorf("timeout updating node for %s", p.name) p.t.Errorf("timeout updating node for %s", p.name)
@ -115,13 +159,37 @@ func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
if ok && o.preferredDERP == derp { if ok && o.preferredDERP == derp {
return return
} }
if err := p.handleOneResp(); err != nil { if err := p.readOneResp(); err != nil {
assert.NoError(p.t, err) assert.NoError(p.t, err)
return return
} }
} }
} }
func (p *Peer) AssertNeverHasDERPs(ctx context.Context, other uuid.UUID, expected ...int32) {
p.t.Helper()
for {
select {
case <-ctx.Done():
return
case resp, ok := <-p.resps:
if !ok {
p.t.Errorf("response channel closed")
return
}
if !assert.NoError(p.t, p.handleResp(resp)) {
return
}
o, ok := p.peers[other]
if !ok {
continue
}
if !assert.NotContains(p.t, expected, o.preferredDERP) {
return
}
}
}
}
func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) { func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
p.t.Helper() p.t.Helper()
for { for {
@ -129,7 +197,7 @@ func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
if !ok { if !ok {
return return
} }
if err := p.handleOneResp(); err != nil { if err := p.readOneResp(); err != nil {
assert.NoError(p.t, err) assert.NoError(p.t, err)
return return
} }
@ -143,7 +211,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
if o.status == proto.CoordinateResponse_PeerUpdate_LOST { if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
return return
} }
if err := p.handleOneResp(); err != nil { if err := p.readOneResp(); err != nil {
assert.NoError(p.t, err) assert.NoError(p.t, err)
return return
} }
@ -153,7 +221,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
func (p *Peer) AssertEventuallyResponsesClosed() { func (p *Peer) AssertEventuallyResponsesClosed() {
p.t.Helper() p.t.Helper()
for { for {
err := p.handleOneResp() err := p.readOneResp()
if xerrors.Is(err, responsesClosed) { if xerrors.Is(err, responsesClosed) {
return return
} }
@ -163,6 +231,32 @@ func (p *Peer) AssertEventuallyResponsesClosed() {
} }
} }
func (p *Peer) AssertNotClosed(d time.Duration) {
p.t.Helper()
// nolint: gocritic // erroneously thinks we're hardcoding non testutil constants here
ctx, cancel := context.WithTimeout(context.Background(), d)
defer cancel()
for {
select {
case <-ctx.Done():
// success!
return
case <-p.ctx.Done():
p.t.Error("main ctx timeout before elapsed time")
return
case resp, ok := <-p.resps:
if !ok {
p.t.Error("response channel closed")
return
}
err := p.handleResp(resp)
if !assert.NoError(p.t, err) {
return
}
}
}
}
func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) { func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
p.t.Helper() p.t.Helper()
for { for {
@ -171,7 +265,7 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
return return
} }
err := p.handleOneResp() err := p.readOneResp()
if xerrors.Is(err, responsesClosed) { if xerrors.Is(err, responsesClosed) {
return return
} }
@ -181,8 +275,9 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
func (p *Peer) AssertEventuallyGetsError(match string) { func (p *Peer) AssertEventuallyGetsError(match string) {
p.t.Helper() p.t.Helper()
for { for {
err := p.handleOneResp() err := p.readOneResp()
if xerrors.Is(err, responsesClosed) { if xerrors.Is(err, responsesClosed) {
p.t.Error("closed before target error")
return return
} }
@ -207,7 +302,7 @@ func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateRespon
var responsesClosed = xerrors.New("responses closed") var responsesClosed = xerrors.New("responses closed")
func (p *Peer) handleOneResp() error { func (p *Peer) readOneResp() error {
select { select {
case <-p.ctx.Done(): case <-p.ctx.Done():
return p.ctx.Err() return p.ctx.Err()
@ -215,31 +310,39 @@ func (p *Peer) handleOneResp() error {
if !ok { if !ok {
return responsesClosed return responsesClosed
} }
if resp.Error != "" { err := p.handleResp(resp)
return xerrors.New(resp.Error) if err != nil {
return err
} }
for _, update := range resp.PeerUpdates { }
id, err := uuid.FromBytes(update.Id) return nil
if err != nil { }
return err
}
p.peerUpdates[id] = append(p.peerUpdates[id], update)
switch update.Kind { func (p *Peer) handleResp(resp *proto.CoordinateResponse) error {
case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST: if resp.Error != "" {
peer := p.peers[id] return xerrors.New(resp.Error)
peer.preferredDERP = update.GetNode().GetPreferredDerp() }
peer.status = update.Kind for _, update := range resp.PeerUpdates {
p.peers[id] = peer id, err := uuid.FromBytes(update.Id)
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED: if err != nil {
delete(p.peers, id) return err
case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE: }
peer := p.peers[id] p.peerUpdates[id] = append(p.peerUpdates[id], update)
peer.readyForHandshake = true
p.peers[id] = peer switch update.Kind {
default: case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
return xerrors.Errorf("unhandled update kind %s", update.Kind) peer := p.peers[id]
} peer.preferredDERP = update.GetNode().GetPreferredDerp()
peer.status = update.Kind
p.peers[id] = peer
case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
delete(p.peers, id)
case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
peer := p.peers[id]
peer.readyForHandshake = true
p.peers[id] = peer
default:
return xerrors.Errorf("unhandled update kind %s", update.Kind)
} }
} }
return nil return nil
@ -261,3 +364,9 @@ func (p *Peer) Close(ctx context.Context) {
} }
} }
} }
func (p *Peer) UngracefulDisconnect(ctx context.Context) {
p.t.Helper()
close(p.reqs)
p.Close(ctx)
}
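
A note on intent, since the name can read oddly: UngracefulDisconnect closes the request channel without sending a Disconnect message first, which is how a peer that simply vanishes looks to the coordinator. A hedged usage sketch (the exact assertion depends on the coordinator under test):

package example_test // hypothetical; would live in some _test.go file

import (
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

func TestUngracefulDisconnect_Sketch(t *testing.T) {
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitShort)
	coord := tailnet.NewCoordinator(logger)
	defer coord.Close()

	agent := test.NewAgent(ctx, t, coord, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(1)

	client := test.NewClient(ctx, t, coord, "client", agent.ID)
	defer client.Close(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 1)

	// Drop the agent without a graceful Disconnect; with the in-memory
	// coordinator we'd expect the remaining peer to eventually see it as
	// lost (adjust the assertion for the coordinator being exercised).
	agent.UngracefulDisconnect(ctx)
	client.AssertEventuallyLost(agent.ID)
}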

View File

@ -1,182 +0,0 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"net"
"sync/atomic"
"time"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet/proto"
)
const (
// WriteTimeout is the amount of time we wait to write a node update to a connection before we
// declare it hung. It is exported so that tests can use it.
WriteTimeout = time.Second * 5
// ResponseBufferSize is the max number of responses to buffer per connection before we start
// dropping updates
ResponseBufferSize = 512
// RequestBufferSize is the max number of requests to buffer per connection
RequestBufferSize = 32
)
type TrackedConn struct {
ctx context.Context
cancel func()
kind QueueKind
conn net.Conn
updates chan *proto.CoordinateResponse
logger slog.Logger
lastData []byte
// ID is an ephemeral UUID used to uniquely identify the owner of the
// connection.
id uuid.UUID
name string
start int64
lastWrite int64
overwrites int64
}
func NewTrackedConn(ctx context.Context, cancel func(),
conn net.Conn,
id uuid.UUID,
logger slog.Logger,
name string,
overwrites int64,
kind QueueKind,
) *TrackedConn {
// buffer updates so they don't block, since we hold the
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
// the most pathological cases.
updates := make(chan *proto.CoordinateResponse, ResponseBufferSize)
now := time.Now().Unix()
return &TrackedConn{
ctx: ctx,
conn: conn,
cancel: cancel,
updates: updates,
logger: logger,
id: id,
start: now,
lastWrite: now,
name: name,
overwrites: overwrites,
kind: kind,
}
}
func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) {
atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
select {
case t.updates <- resp:
return nil
default:
return ErrWouldBlock
}
}
func (t *TrackedConn) UniqueID() uuid.UUID {
return t.id
}
func (t *TrackedConn) Kind() QueueKind {
return t.kind
}
func (t *TrackedConn) Name() string {
return t.name
}
func (t *TrackedConn) Stats() (start, lastWrite int64) {
return t.start, atomic.LoadInt64(&t.lastWrite)
}
func (t *TrackedConn) Overwrites() int64 {
return t.overwrites
}
func (t *TrackedConn) CoordinatorClose() error {
return t.Close()
}
func (t *TrackedConn) Done() <-chan struct{} {
return t.ctx.Done()
}
// Close the connection and cancel the context for reading node updates from the queue
func (t *TrackedConn) Close() error {
t.cancel()
return t.conn.Close()
}
// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
// canceled.
func (t *TrackedConn) SendUpdates() {
for {
select {
case <-t.ctx.Done():
t.logger.Debug(t.ctx, "done sending updates")
return
case resp := <-t.updates:
nodes, err := OnlyNodeUpdates(resp)
if err != nil {
t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err))
return
}
if len(nodes) == 0 {
t.logger.Debug(t.ctx, "skipping response with no nodes")
continue
}
data, err := json.Marshal(nodes)
if err != nil {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
return
}
if bytes.Equal(t.lastData, data) {
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
_ = t.Close()
return
}
_, err = t.conn.Write(data)
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "could not write nodes to connection",
slog.Error(err), slog.F("nodes", string(data)))
_ = t.Close()
return
}
t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
// our successful write, it is important that we reset the deadline before it fires.
err = t.conn.SetWriteDeadline(time.Time{})
if err != nil {
// often, this is just because the connection is closed/broken, so only log at debug.
t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
_ = t.Close()
return
}
t.lastData = data
}
}
}
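
For the record, the wire format that disappears with this file: a v1 peer wrote single JSON-encoded Node objects and read back JSON arrays of Nodes. A minimal sketch of that now-removed framing, assuming a raw net.Conn (helper names are hypothetical):

package example // hypothetical illustration of the removed v1 framing

import (
	"encoding/json"
	"net"

	"github.com/coder/coder/v2/tailnet"
)

// writeV1Node sends one node update the way the deleted v1ReqLoop expected:
// a bare JSON object per update.
func writeV1Node(conn net.Conn, node *tailnet.Node) error {
	return json.NewEncoder(conn).Encode(node)
}

// readV1Nodes reads one batch of peer nodes the way the deleted SendUpdates
// wrote them: a JSON array per write.
func readV1Nodes(conn net.Conn) ([]*tailnet.Node, error) {
	var nodes []*tailnet.Node
	err := json.NewDecoder(conn).Decode(&nodes)
	return nodes, err
}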