diff --git a/enterprise/tailnet/multiagent_test.go b/enterprise/tailnet/multiagent_test.go index bbb3c55735..288c4fb86f 100644 --- a/enterprise/tailnet/multiagent_test.go +++ b/enterprise/tailnet/multiagent_test.go @@ -10,8 +10,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/enterprise/tailnet" - agpl "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/tailnettest" + agpltest "github.com/coder/coder/v2/tailnet/test" "github.com/coder/coder/v2/testutil" ) @@ -35,27 +35,27 @@ func TestPGCoordinator_MultiAgent(t *testing.T) { require.NoError(t, err) defer coord1.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) ma1 := tailnettest.NewTestMultiAgent(t, coord1) defer ma1.Close() - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.SendNodeWithDERP(3) - assertEventuallyHasDERPs(ctx, t, agent1, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) ma1.Close() - require.NoError(t, agent1.close()) + agent1.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) { @@ -102,28 +102,28 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) { require.NoError(t, err) defer coord1.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) ma1 := tailnettest.NewTestMultiAgent(t, coord1) defer ma1.Close() - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.SendNodeWithDERP(3) - assertEventuallyHasDERPs(ctx, t, agent1, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) - ma1.RequireUnsubscribeAgent(agent1.id) + ma1.RequireUnsubscribeAgent(agent1.ID) ma1.Close() - require.NoError(t, agent1.close()) + agent1.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } // TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a @@ -147,43 +147,43 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) { require.NoError(t, err) defer coord1.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) ma1 := tailnettest.NewTestMultiAgent(t, coord1) defer ma1.Close() - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) 
ma1.SendNodeWithDERP(3) - assertEventuallyHasDERPs(ctx, t, agent1, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) - ma1.RequireUnsubscribeAgent(agent1.id) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + ma1.RequireUnsubscribeAgent(agent1.ID) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() ma1.SendNodeWithDERP(9) - assertNeverHasDERPs(ctx, t, agent1, 9) + agent1.AssertNeverHasDERPs(ctx, ma1.ID, 9) }() func() { ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3) defer cancel() - agent1.sendNode(&agpl.Node{PreferredDERP: 8}) + agent1.UpdateDERP(8) ma1.RequireNeverHasDERPs(ctx, 8) }() ma1.Close() - require.NoError(t, agent1.close()) + agent1.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } // TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a @@ -212,27 +212,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) { require.NoError(t, err) defer coord2.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) ma1 := tailnettest.NewTestMultiAgent(t, coord2) defer ma1.Close() - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.SendNodeWithDERP(3) - assertEventuallyHasDERPs(ctx, t, agent1, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) ma1.Close() - require.NoError(t, agent1.close()) + agent1.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } // TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two @@ -262,27 +262,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test require.NoError(t, err) defer coord2.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) ma1 := tailnettest.NewTestMultiAgent(t, coord2) defer ma1.Close() ma1.SendNodeWithDERP(3) - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - assertEventuallyHasDERPs(ctx, t, agent1, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) ma1.Close() - require.NoError(t, agent1.close()) + agent1.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } // TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a @@ -317,37 +317,37 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) { require.NoError(t, err) defer coord3.Close() - agent1 := 
newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent1.sendNode(&agpl.Node{PreferredDERP: 5}) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent1.UpdateDERP(5) - agent2 := newTestAgent(t, coord2, "agent2") - defer agent1.close() - agent2.sendNode(&agpl.Node{PreferredDERP: 6}) + agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2") + defer agent2.Close(ctx) + agent2.UpdateDERP(6) ma1 := tailnettest.NewTestMultiAgent(t, coord3) defer ma1.Close() - ma1.RequireSubscribeAgent(agent1.id) + ma1.RequireSubscribeAgent(agent1.ID) ma1.RequireEventuallyHasDERPs(ctx, 5) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + agent1.UpdateDERP(1) ma1.RequireEventuallyHasDERPs(ctx, 1) - ma1.RequireSubscribeAgent(agent2.id) + ma1.RequireSubscribeAgent(agent2.ID) ma1.RequireEventuallyHasDERPs(ctx, 6) - agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + agent2.UpdateDERP(2) ma1.RequireEventuallyHasDERPs(ctx, 2) ma1.SendNodeWithDERP(3) - assertEventuallyHasDERPs(ctx, t, agent1, 3) - assertEventuallyHasDERPs(ctx, t, agent2, 3) + agent1.AssertEventuallyHasDERP(ma1.ID, 3) + agent2.AssertEventuallyHasDERP(ma1.ID, 3) ma1.Close() - require.NoError(t, agent1.close()) - require.NoError(t, agent2.close()) + agent1.UngracefulDisconnect(ctx) + agent2.UngracefulDisconnect(ctx) - assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index be4722a02f..f8530ca990 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -3,7 +3,6 @@ package tailnet import ( "context" "database/sql" - "net" "strings" "sync" "sync/atomic" @@ -213,14 +212,6 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { return node } -func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent) -} - -func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { - return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name) -} - func (c *pgCoord) Close() error { c.logger.Info(c.ctx, "closing coordinator") c.cancel() diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go index 118fdead98..dc9b4e2806 100644 --- a/enterprise/tailnet/pgcoord_test.go +++ b/enterprise/tailnet/pgcoord_test.go @@ -3,8 +3,6 @@ package tailnet_test import ( "context" "database/sql" - "io" - "net" "net/netip" "sync" "testing" @@ -15,7 +13,6 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" "go.uber.org/mock/gomock" - "golang.org/x/exp/slices" "golang.org/x/xerrors" gProto "google.golang.org/protobuf/proto" @@ -51,9 +48,9 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { defer coordinator.Close() agentID := uuid.New() - client := newTestClient(t, coordinator, agentID) - defer client.close() - client.sendNode(&agpl.Node{PreferredDERP: 10}) + client := agpltest.NewClient(ctx, t, coordinator, "client", agentID) + defer client.Close(ctx) + client.UpdateDERP(10) require.Eventually(t, func() bool { clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { @@ -68,12 +65,8 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { assert.EqualValues(t, 10, node.PreferredDerp) return true }, testutil.WaitShort, testutil.IntervalFast) - - err = 
client.close() - require.NoError(t, err) - <-client.errChan - <-client.closeChan - assertEventuallyLost(ctx, t, store, client.id) + client.UngracefulDisconnect(ctx) + assertEventuallyLost(ctx, t, store, client.ID) } func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { @@ -89,11 +82,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "agent") - defer agent.close() - agent.sendNode(&agpl.Node{PreferredDERP: 10}) + agent := agpltest.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateDERP(10) require.Eventually(t, func() bool { - agents, err := store.GetTailnetPeers(ctx, agent.id) + agents, err := store.GetTailnetPeers(ctx, agent.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { t.Fatalf("database error: %v", err) } @@ -106,11 +99,8 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { assert.EqualValues(t, 10, node.PreferredDerp) return true }, testutil.WaitShort, testutil.IntervalFast) - err = agent.close() - require.NoError(t, err) - <-agent.errChan - <-agent.closeChan - assertEventuallyLost(ctx, t, store, agent.id) + agent.UngracefulDisconnect(ctx) + assertEventuallyLost(ctx, t, store, agent.ID) } func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) { @@ -126,18 +116,18 @@ func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "agent") - defer agent.close() - agent.sendNode(&agpl.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(agpl.IP(), 128), + agent := agpltest.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(agpl.IP(), 128).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) // The agent connection should be closed immediately after sending an invalid addr - testutil.RequireRecvCtx(ctx, t, agent.closeChan) - assertEventuallyLost(ctx, t, store, agent.id) + agent.AssertEventuallyResponsesClosed() + assertEventuallyLost(ctx, t, store, agent.ID) } func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) { @@ -153,18 +143,18 @@ func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "agent") - defer agent.close() - agent.sendNode(&agpl.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64), + agent := agpltest.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 64).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) // The agent connection should be closed immediately after sending an invalid addr - testutil.RequireRecvCtx(ctx, t, agent.closeChan) - assertEventuallyLost(ctx, t, store, agent.id) + agent.AssertEventuallyResponsesClosed() + assertEventuallyLost(ctx, t, store, agent.ID) } func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) { @@ -180,16 +170,16 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "agent") - defer agent.close() - agent.sendNode(&agpl.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128), + agent := agpltest.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ 
+ netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 128).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) require.Eventually(t, func() bool { - agents, err := store.GetTailnetPeers(ctx, agent.id) + agents, err := store.GetTailnetPeers(ctx, agent.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { t.Fatalf("database error: %v", err) } @@ -202,11 +192,8 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) { assert.EqualValues(t, 10, node.PreferredDerp) return true }, testutil.WaitShort, testutil.IntervalFast) - err = agent.close() - require.NoError(t, err) - <-agent.errChan - <-agent.closeChan - assertEventuallyLost(ctx, t, store, agent.id) + agent.UngracefulDisconnect(ctx) + assertEventuallyLost(ctx, t, store, agent.ID) } func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { @@ -222,68 +209,40 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "original") - defer agent.close() - agent.sendNode(&agpl.Node{PreferredDERP: 10}) + agent := agpltest.NewAgent(ctx, t, coordinator, "original") + defer agent.Close(ctx) + agent.UpdateDERP(10) - client := newTestClient(t, coordinator, agent.id) - defer client.close() + client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID) + defer client.Close(ctx) - agentNodes := client.recvNodes(ctx, t) - require.Len(t, agentNodes, 1) - assert.Equal(t, 10, agentNodes[0].PreferredDERP) - client.sendNode(&agpl.Node{PreferredDERP: 11}) - clientNodes := agent.recvNodes(ctx, t) - require.Len(t, clientNodes, 1) - assert.Equal(t, 11, clientNodes[0].PreferredDERP) + client.AssertEventuallyHasDERP(agent.ID, 10) + client.UpdateDERP(11) + agent.AssertEventuallyHasDERP(client.ID, 11) // Ensure an update to the agent node reaches the connIO! - agent.sendNode(&agpl.Node{PreferredDERP: 12}) - agentNodes = client.recvNodes(ctx, t) - require.Len(t, agentNodes, 1) - assert.Equal(t, 12, agentNodes[0].PreferredDERP) + agent.UpdateDERP(12) + client.AssertEventuallyHasDERP(agent.ID, 12) - // Close the agent WebSocket so a new one can connect. - err = agent.close() - require.NoError(t, err) - _ = agent.recvErr(ctx, t) - agent.waitForClose(ctx, t) + // Close the agent channel so a new one can connect. + agent.Close(ctx) // Create a new agent connection. This is to simulate a reconnect! - agent = newTestAgent(t, coordinator, "reconnection", agent.id) - // Ensure the existing listening connIO sends its node immediately! - clientNodes = agent.recvNodes(ctx, t) - require.Len(t, clientNodes, 1) - assert.Equal(t, 11, clientNodes[0].PreferredDERP) + agent = agpltest.NewPeer(ctx, t, coordinator, "reconnection", agpltest.WithID(agent.ID)) + // Ensure the coordinator sends its client node immediately! + agent.AssertEventuallyHasDERP(client.ID, 11) // Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the // coordinator accidentally reordering things. - for d := 13; d < 36; d++ { - agent.sendNode(&agpl.Node{PreferredDERP: d}) - } - for { - nodes := client.recvNodes(ctx, t) - if !assert.Len(t, nodes, 1) { - break - } - if nodes[0].PreferredDERP == 35 { - // got latest! 
- break - } + for d := int32(13); d < 36; d++ { + agent.UpdateDERP(d) } + client.AssertEventuallyHasDERP(agent.ID, 35) - err = agent.close() - require.NoError(t, err) - _ = agent.recvErr(ctx, t) - agent.waitForClose(ctx, t) - - err = client.close() - require.NoError(t, err) - _ = client.recvErr(ctx, t) - client.waitForClose(ctx, t) - - assertEventuallyLost(ctx, t, store, agent.id) - assertEventuallyLost(ctx, t, store, client.id) + agent.UngracefulDisconnect(ctx) + client.UngracefulDisconnect(ctx) + assertEventuallyLost(ctx, t, store, agent.ID) + assertEventuallyLost(ctx, t, store, client.ID) } func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { @@ -305,16 +264,16 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "agent") - defer agent.close() - agent.sendNode(&agpl.Node{PreferredDERP: 10}) + agent := agpltest.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateDERP(10) - client := newTestClient(t, coordinator, agent.id) - defer client.close() + client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID) + defer client.Close(ctx) - assertEventuallyHasDERPs(ctx, t, client, 10) - client.sendNode(&agpl.Node{PreferredDERP: 11}) - assertEventuallyHasDERPs(ctx, t, agent, 11) + client.AssertEventuallyHasDERP(agent.ID, 10) + client.UpdateDERP(11) + agent.AssertEventuallyHasDERP(client.ID, 11) // simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a // real coordinator @@ -328,8 +287,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { fCoord2.heartbeat() afTrap.MustWait(ctx).Release() // heartbeat timeout started - fCoord2.agentNode(agent.id, &agpl.Node{PreferredDERP: 12}) - assertEventuallyHasDERPs(ctx, t, client, 12) + fCoord2.agentNode(agent.ID, &agpl.Node{PreferredDERP: 12}) + client.AssertEventuallyHasDERP(agent.ID, 12) fCoord3 := &fakeCoordinator{ ctx: ctx, @@ -339,8 +298,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { } fCoord3.heartbeat() rstTrap.MustWait(ctx).Release() // timeout gets reset - fCoord3.agentNode(agent.id, &agpl.Node{PreferredDERP: 13}) - assertEventuallyHasDERPs(ctx, t, client, 13) + fCoord3.agentNode(agent.ID, &agpl.Node{PreferredDERP: 13}) + client.AssertEventuallyHasDERP(agent.ID, 13) // fCoord2 sends in a second heartbeat, one period later (on time) mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) @@ -353,30 +312,22 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { w := mClock.Advance(tailnet.HeartbeatPeriod) rstTrap.MustWait(ctx).Release() w.MustWait(ctx) - assertEventuallyHasDERPs(ctx, t, client, 12) + client.AssertEventuallyHasDERP(agent.ID, 12) // one more heartbeat period will result in fCoord2 being expired, which should cause us to // revert to the original agent mapping mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx) // note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired - assertEventuallyHasDERPs(ctx, t, client, 10) + client.AssertEventuallyHasDERP(agent.ID, 10) // send fCoord3 heartbeat, which should trigger us to consider that mapping valid again. 
fCoord3.heartbeat() rstTrap.MustWait(ctx).Release() // timeout gets reset - assertEventuallyHasDERPs(ctx, t, client, 13) + client.AssertEventuallyHasDERP(agent.ID, 13) - err = agent.close() - require.NoError(t, err) - _ = agent.recvErr(ctx, t) - agent.waitForClose(ctx, t) - - err = client.close() - require.NoError(t, err) - _ = client.recvErr(ctx, t) - client.waitForClose(ctx, t) - - assertEventuallyLost(ctx, t, store, client.id) + agent.UngracefulDisconnect(ctx) + client.UngracefulDisconnect(ctx) + assertEventuallyLost(ctx, t, store, client.ID) } func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) { @@ -420,7 +371,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) { // disconnect. client.AssertEventuallyLost(agentID) - client.Close(ctx) + client.UngracefulDisconnect(ctx) assertEventuallyLost(ctx, t, store, client.ID) } @@ -491,104 +442,73 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) { require.NoError(t, err) defer coord2.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - t.Logf("agent1=%s", agent1.id) - agent2 := newTestAgent(t, coord2, "agent2") - defer agent2.close() - t.Logf("agent2=%s", agent2.id) + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + t.Logf("agent1=%s", agent1.ID) + agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2") + defer agent2.Close(ctx) + t.Logf("agent2=%s", agent2.ID) - client11 := newTestClient(t, coord1, agent1.id) - defer client11.close() - t.Logf("client11=%s", client11.id) - client12 := newTestClient(t, coord1, agent2.id) - defer client12.close() - t.Logf("client12=%s", client12.id) - client21 := newTestClient(t, coord2, agent1.id) - defer client21.close() - t.Logf("client21=%s", client21.id) - client22 := newTestClient(t, coord2, agent2.id) - defer client22.close() - t.Logf("client22=%s", client22.id) + client11 := agpltest.NewClient(ctx, t, coord1, "client11", agent1.ID) + defer client11.Close(ctx) + t.Logf("client11=%s", client11.ID) + client12 := agpltest.NewClient(ctx, t, coord1, "client12", agent2.ID) + defer client12.Close(ctx) + t.Logf("client12=%s", client12.ID) + client21 := agpltest.NewClient(ctx, t, coord2, "client21", agent1.ID) + defer client21.Close(ctx) + t.Logf("client21=%s", client21.ID) + client22 := agpltest.NewClient(ctx, t, coord2, "client22", agent2.ID) + defer client22.Close(ctx) + t.Logf("client22=%s", client22.ID) t.Logf("client11 -> Node 11") - client11.sendNode(&agpl.Node{PreferredDERP: 11}) - assertEventuallyHasDERPs(ctx, t, agent1, 11) + client11.UpdateDERP(11) + agent1.AssertEventuallyHasDERP(client11.ID, 11) t.Logf("client21 -> Node 21") - client21.sendNode(&agpl.Node{PreferredDERP: 21}) - assertEventuallyHasDERPs(ctx, t, agent1, 21) + client21.UpdateDERP(21) + agent1.AssertEventuallyHasDERP(client21.ID, 21) t.Logf("client22 -> Node 22") - client22.sendNode(&agpl.Node{PreferredDERP: 22}) - assertEventuallyHasDERPs(ctx, t, agent2, 22) + client22.UpdateDERP(22) + agent2.AssertEventuallyHasDERP(client22.ID, 22) t.Logf("agent2 -> Node 2") - agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - assertEventuallyHasDERPs(ctx, t, client22, 2) - assertEventuallyHasDERPs(ctx, t, client12, 2) + agent2.UpdateDERP(2) + client22.AssertEventuallyHasDERP(agent2.ID, 2) + client12.AssertEventuallyHasDERP(agent2.ID, 2) t.Logf("client12 -> Node 12") - client12.sendNode(&agpl.Node{PreferredDERP: 12}) - assertEventuallyHasDERPs(ctx, t, agent2, 12) + client12.UpdateDERP(12) + agent2.AssertEventuallyHasDERP(client12.ID, 12) t.Logf("agent1 -> Node 1") - 
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertEventuallyHasDERPs(ctx, t, client21, 1) - assertEventuallyHasDERPs(ctx, t, client11, 1) + agent1.UpdateDERP(1) + client21.AssertEventuallyHasDERP(agent1.ID, 1) + client11.AssertEventuallyHasDERP(agent1.ID, 1) t.Logf("close coord2") err = coord2.Close() require.NoError(t, err) // this closes agent2, client22, client21 - err = agent2.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - err = client22.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - err = client21.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - assertEventuallyLost(ctx, t, store, agent2.id) - assertEventuallyLost(ctx, t, store, client21.id) - assertEventuallyLost(ctx, t, store, client22.id) + agent2.AssertEventuallyResponsesClosed() + client22.AssertEventuallyResponsesClosed() + client21.AssertEventuallyResponsesClosed() + assertEventuallyLost(ctx, t, store, agent2.ID) + assertEventuallyLost(ctx, t, store, client21.ID) + assertEventuallyLost(ctx, t, store, client22.ID) err = coord1.Close() require.NoError(t, err) // this closes agent1, client12, client11 - err = agent1.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - err = client12.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - err = client11.recvErr(ctx, t) - require.ErrorIs(t, err, io.EOF) - assertEventuallyLost(ctx, t, store, agent1.id) - assertEventuallyLost(ctx, t, store, client11.id) - assertEventuallyLost(ctx, t, store, client12.id) - - // wait for all connections to close - err = agent1.close() - require.NoError(t, err) - agent1.waitForClose(ctx, t) - - err = agent2.close() - require.NoError(t, err) - agent2.waitForClose(ctx, t) - - err = client11.close() - require.NoError(t, err) - client11.waitForClose(ctx, t) - - err = client12.close() - require.NoError(t, err) - client12.waitForClose(ctx, t) - - err = client21.close() - require.NoError(t, err) - client21.waitForClose(ctx, t) - - err = client22.close() - require.NoError(t, err) - client22.waitForClose(ctx, t) + agent1.AssertEventuallyResponsesClosed() + client12.AssertEventuallyResponsesClosed() + client11.AssertEventuallyResponsesClosed() + assertEventuallyLost(ctx, t, store, agent1.ID) + assertEventuallyLost(ctx, t, store, client11.ID) + assertEventuallyLost(ctx, t, store, client12.ID) } // TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators. 
@@ -623,53 +543,42 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) { require.NoError(t, err) defer coord3.Close() - agent1 := newTestAgent(t, coord1, "agent1") - defer agent1.close() - agent2 := newTestAgent(t, coord2, "agent2", agent1.id) - defer agent2.close() + agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1") + defer agent1.Close(ctx) + agent2 := agpltest.NewPeer(ctx, t, coord2, "agent2", + agpltest.WithID(agent1.ID), agpltest.WithAuth(agpl.AgentCoordinateeAuth{ID: agent1.ID}), + ) + defer agent2.Close(ctx) - client := newTestClient(t, coord3, agent1.id) - defer client.close() + client := agpltest.NewClient(ctx, t, coord3, "client", agent1.ID) + defer client.Close(ctx) - client.sendNode(&agpl.Node{PreferredDERP: 3}) - assertEventuallyHasDERPs(ctx, t, agent1, 3) - assertEventuallyHasDERPs(ctx, t, agent2, 3) + client.UpdateDERP(3) + agent1.AssertEventuallyHasDERP(client.ID, 3) + agent2.AssertEventuallyHasDERP(client.ID, 3) - agent1.sendNode(&agpl.Node{PreferredDERP: 1}) - assertEventuallyHasDERPs(ctx, t, client, 1) + agent1.UpdateDERP(1) + client.AssertEventuallyHasDERP(agent1.ID, 1) // agent2's update overrides agent1 because it is newer - agent2.sendNode(&agpl.Node{PreferredDERP: 2}) - assertEventuallyHasDERPs(ctx, t, client, 2) + agent2.UpdateDERP(2) + client.AssertEventuallyHasDERP(agent1.ID, 2) // agent2 disconnects, and we should revert back to agent1 - err = agent2.close() - require.NoError(t, err) - err = agent2.recvErr(ctx, t) - require.ErrorIs(t, err, io.ErrClosedPipe) - agent2.waitForClose(ctx, t) - assertEventuallyHasDERPs(ctx, t, client, 1) + agent2.Close(ctx) + client.AssertEventuallyHasDERP(agent1.ID, 1) - agent1.sendNode(&agpl.Node{PreferredDERP: 11}) - assertEventuallyHasDERPs(ctx, t, client, 11) + agent1.UpdateDERP(11) + client.AssertEventuallyHasDERP(agent1.ID, 11) - client.sendNode(&agpl.Node{PreferredDERP: 31}) - assertEventuallyHasDERPs(ctx, t, agent1, 31) + client.UpdateDERP(31) + agent1.AssertEventuallyHasDERP(client.ID, 31) - err = agent1.close() - require.NoError(t, err) - err = agent1.recvErr(ctx, t) - require.ErrorIs(t, err, io.ErrClosedPipe) - agent1.waitForClose(ctx, t) + agent1.UngracefulDisconnect(ctx) + client.UngracefulDisconnect(ctx) - err = client.close() - require.NoError(t, err) - err = client.recvErr(ctx, t) - require.ErrorIs(t, err, io.ErrClosedPipe) - client.waitForClose(ctx, t) - - assertEventuallyLost(ctx, t, store, client.id) - assertEventuallyLost(ctx, t, store, agent1.id) + assertEventuallyLost(ctx, t, store, client.ID) + assertEventuallyLost(ctx, t, store, agent1.ID) } func TestPGCoordinator_Unhealthy(t *testing.T) { @@ -683,7 +592,13 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) calls := make(chan struct{}) + // first call succeeds, so that our Agent will successfully connect. + firstSucceeds := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). + Times(1). + Return(database.TailnetCoordinator{}, nil) + // next 3 fail, so the Coordinator becomes unhealthy, and we test that it disconnects the agent threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()). + After(firstSucceeds). Times(3). Do(func(_ context.Context, _ uuid.UUID) { <-calls }). 
Return(database.TailnetCoordinator{}, xerrors.New("test disconnect")) @@ -710,23 +625,23 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { err := uut.Close() require.NoError(t, err) }() - agent1 := newTestAgent(t, uut, "agent1") - defer agent1.close() + agent1 := agpltest.NewAgent(ctx, t, uut, "agent1") + defer agent1.Close(ctx) for i := 0; i < 3; i++ { select { case <-ctx.Done(): - t.Fatal("timeout") + t.Fatalf("timeout waiting for call %d", i+1) case calls <- struct{}{}: // OK } } // connected agent should be disconnected - agent1.waitForClose(ctx, t) + agent1.AssertEventuallyResponsesClosed() // new agent should immediately disconnect - agent2 := newTestAgent(t, uut, "agent2") - defer agent2.close() - agent2.waitForClose(ctx, t) + agent2 := agpltest.NewAgent(ctx, t, uut, "agent2") + defer agent2.Close(ctx) + agent2.AssertEventuallyResponsesClosed() // next heartbeats succeed, so we are healthy for i := 0; i < 2; i++ { @@ -737,14 +652,9 @@ func TestPGCoordinator_Unhealthy(t *testing.T) { // OK } } - agent3 := newTestAgent(t, uut, "agent3") - defer agent3.close() - select { - case <-agent3.closeChan: - t.Fatal("agent conn closed after we are healthy") - case <-time.After(time.Second): - // OK - } + agent3 := agpltest.NewAgent(ctx, t, uut, "agent3") + defer agent3.Close(ctx) + agent3.AssertNotClosed(time.Second) } func TestPGCoordinator_Node_Empty(t *testing.T) { @@ -840,43 +750,39 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) { require.NoError(t, err) defer coordinator.Close() - agent := newTestAgent(t, coordinator, "original") - defer agent.close() - agent.sendNode(&agpl.Node{PreferredDERP: 10}) + agent := agpltest.NewAgent(ctx, t, coordinator, "original") + defer agent.Close(ctx) + agent.UpdateDERP(10) - client := newTestClient(t, coordinator, agent.id) - defer client.close() + client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID) + defer client.Close(ctx) // Simulate some traffic to generate // a peer. 
- agentNodes := client.recvNodes(ctx, t) - require.Len(t, agentNodes, 1) - assert.Equal(t, 10, agentNodes[0].PreferredDERP) - client.sendNode(&agpl.Node{PreferredDERP: 11}) + client.AssertEventuallyHasDERP(agent.ID, 10) + client.UpdateDERP(11) - clientNodes := agent.recvNodes(ctx, t) - require.Len(t, clientNodes, 1) - assert.Equal(t, 11, clientNodes[0].PreferredDERP) + agent.AssertEventuallyHasDERP(client.ID, 11) - anode := coordinator.Node(agent.id) + anode := coordinator.Node(agent.ID) require.NotNil(t, anode) - cnode := coordinator.Node(client.id) + cnode := coordinator.Node(client.ID) require.NotNil(t, cnode) err = coordinator.Close() require.NoError(t, err) - assertEventuallyLost(ctx, t, store, agent.id) - assertEventuallyLost(ctx, t, store, client.id) + assertEventuallyLost(ctx, t, store, agent.ID) + assertEventuallyLost(ctx, t, store, client.ID) coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store) require.NoError(t, err) defer coordinator2.Close() - anode = coordinator2.Node(agent.id) + anode = coordinator2.Node(agent.ID) require.NotNil(t, anode) assert.Equal(t, 10, anode.PreferredDERP) - cnode = coordinator2.Node(client.id) + cnode = coordinator2.Node(client.ID) require.NotNil(t, cnode) assert.Equal(t, 11, cnode.PreferredDERP) } @@ -1007,144 +913,6 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) { p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) } -type testConn struct { - ws, serverWS net.Conn - nodeChan chan []*agpl.Node - sendNode func(node *agpl.Node) - errChan <-chan error - id uuid.UUID - closeChan chan struct{} -} - -func newTestConn(ids []uuid.UUID) *testConn { - a := &testConn{} - a.ws, a.serverWS = net.Pipe() - a.nodeChan = make(chan []*agpl.Node) - a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error { - a.nodeChan <- nodes - return nil - }) - if len(ids) > 1 { - panic("too many") - } - if len(ids) == 1 { - a.id = ids[0] - } else { - a.id = uuid.New() - } - a.closeChan = make(chan struct{}) - return a -} - -func newTestAgent(t *testing.T, coord agpl.CoordinatorV1, name string, id ...uuid.UUID) *testConn { - a := newTestConn(id) - go func() { - err := coord.ServeAgent(a.serverWS, a.id, name) - assert.NoError(t, err) - close(a.closeChan) - }() - return a -} - -func newTestClient(t *testing.T, coord agpl.CoordinatorV1, agentID uuid.UUID, id ...uuid.UUID) *testConn { - c := newTestConn(id) - go func() { - err := coord.ServeClient(c.serverWS, c.id, agentID) - assert.NoError(t, err) - close(c.closeChan) - }() - return c -} - -func (c *testConn) close() error { - return c.ws.Close() -} - -func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node { - t.Helper() - select { - case <-ctx.Done(): - t.Fatalf("testConn id %s: timeout receiving nodes ", c.id) - return nil - case nodes := <-c.nodeChan: - return nodes - } -} - -func (c *testConn) recvErr(ctx context.Context, t *testing.T) error { - t.Helper() - // pgCoord works on eventual consistency, so it sometimes sends extra node - // updates, and these block errors if not read from the nodes channel. 
- for { - select { - case nodes := <-c.nodeChan: - t.Logf("ignoring nodes update while waiting for error; id=%s, nodes=%+v", - c.id.String(), nodes) - continue - case <-ctx.Done(): - t.Fatal("timeout receiving error") - return ctx.Err() - case err := <-c.errChan: - return err - } - } -} - -func (c *testConn) waitForClose(ctx context.Context, t *testing.T) { - t.Helper() - select { - case <-ctx.Done(): - t.Fatal("timeout waiting for connection to close") - return - case <-c.closeChan: - return - } -} - -func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) { - t.Helper() - for { - nodes := c.recvNodes(ctx, t) - if len(nodes) != len(expected) { - t.Logf("expected %d, got %d nodes", len(expected), len(nodes)) - continue - } - - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if !slices.Contains(derps, e) { - t.Logf("expected DERP %d to be in %v", e, derps) - continue - } - return - } - } -} - -func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) { - t.Helper() - for { - select { - case <-ctx.Done(): - return - case nodes := <-c.nodeChan: - derps := make([]int, 0, len(nodes)) - for _, n := range nodes { - derps = append(derps, n.PreferredDERP) - } - for _, e := range expected { - if slices.Contains(derps, e) { - t.Fatalf("expected not to get DERP %d, but received it", e) - return - } - } - } - } -} - func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) { t.Helper() assert.Eventually(t, func() bool { diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go index dcadc4805d..de95c18577 100644 --- a/enterprise/tailnet/workspaceproxy.go +++ b/enterprise/tailnet/workspaceproxy.go @@ -1,19 +1,13 @@ package tailnet import ( - "bytes" "context" - "encoding/json" - "errors" "net" - "time" "github.com/google/uuid" - "golang.org/x/xerrors" "cdr.dev/slog" "github.com/coder/coder/v2/apiversion" - "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" agpl "github.com/coder/coder/v2/tailnet" ) @@ -38,10 +32,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin return err } switch major { - case 1: - coord := *(s.CoordPtr.Load()) - sub := coord.ServeMultiAgent(id) - return ServeWorkspaceProxy(ctx, conn, sub) case 2: auth := agpl.SingleTailnetCoordinateeAuth{} streamID := agpl.StreamID{ @@ -52,103 +42,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin return s.ServeConnV2(ctx, conn, streamID) default: s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) - return xerrors.New("unsupported version") - } -} - -func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { - go func() { - //nolint:staticcheck - err := forwardNodesToWorkspaceProxy(ctx, conn, ma) - //nolint:staticcheck - if err != nil { - _ = conn.Close() - } - }() - - decoder := json.NewDecoder(conn) - for { - var msg wsproxysdk.CoordinateMessage - err := decoder.Decode(&msg) - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return xerrors.Errorf("read json: %w", err) - } - - switch msg.Type { - case wsproxysdk.CoordinateMessageTypeSubscribe: - err := ma.SubscribeAgent(msg.AgentID) - if err != nil { - return xerrors.Errorf("subscribe agent: %w", err) - } - case wsproxysdk.CoordinateMessageTypeUnsubscribe: 
- err := ma.UnsubscribeAgent(msg.AgentID) - if err != nil { - return xerrors.Errorf("unsubscribe agent: %w", err) - } - case wsproxysdk.CoordinateMessageTypeNodeUpdate: - pn, err := agpl.NodeToProto(msg.Node) - if err != nil { - return err - } - err = ma.UpdateSelf(pn) - if err != nil { - return xerrors.Errorf("update self: %w", err) - } - - default: - return xerrors.Errorf("unknown message type %q", msg.Type) - } - } -} - -// Linter fails because this function always returns an error. This function blocks -// until it errors, so this is ok. -// -//nolint:staticcheck -func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { - var lastData []byte - for { - resp, ok := ma.NextUpdate(ctx) - if !ok { - return xerrors.New("multiagent is closed") - } - nodes, err := agpl.OnlyNodeUpdates(resp) - if err != nil { - return xerrors.Errorf("failed to convert response: %w", err) - } - data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes}) - if err != nil { - return err - } - if bytes.Equal(lastData, data) { - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - return err - } - _, err = conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - return err - } - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. 
- err = conn.SetWriteDeadline(time.Time{}) - if err != nil { - return err - } - lastData = data + return agpl.ErrUnsupportedVersion } } diff --git a/site/e2e/tests/outdatedCLI.spec.ts b/site/e2e/tests/outdatedCLI.spec.ts index 6118f195ff..22301483e0 100644 --- a/site/e2e/tests/outdatedCLI.spec.ts +++ b/site/e2e/tests/outdatedCLI.spec.ts @@ -11,6 +11,7 @@ import { } from "../helpers"; import { beforeCoderTest } from "../hooks"; +// we no longer support versions prior to Tailnet v2 API support: https://github.com/coder/coder/commit/059e533544a0268acbc8831006b2858ead2f0d8e const clientVersion = "v2.8.0"; test.beforeEach(({ page }) => beforeCoderTest(page)); diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 6586c9917c..cc50c792f1 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -2,11 +2,9 @@ package tailnet import ( "context" - "encoding/json" "fmt" "html/template" "io" - "net" "net/http" "net/netip" "sync" @@ -14,7 +12,6 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" - "nhooyr.io/websocket" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -22,35 +19,23 @@ import ( "github.com/coder/coder/v2/tailnet/proto" ) +const ( + // ResponseBufferSize is the max number of responses to buffer per connection before we start + // dropping updates + ResponseBufferSize = 512 + // RequestBufferSize is the max number of requests to buffer per connection + RequestBufferSize = 32 +) + // Coordinator exchanges nodes with agents to establish connections. // ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐ // │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│ // └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘ // Coordinators have different guarantees for HA support. type Coordinator interface { - CoordinatorV1 CoordinatorV2 } -type CoordinatorV1 interface { - // ServeHTTPDebug serves a debug webpage that shows the internal state of - // the coordinator. - ServeHTTPDebug(w http.ResponseWriter, r *http.Request) - // Node returns an in-memory node by ID. - Node(id uuid.UUID) *Node - // ServeClient accepts a WebSocket connection that wants to connect to an agent - // with the specified ID. - ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error - // ServeAgent accepts a WebSocket connection to an agent that listens to - // incoming connections and publishes node updates. - // Name is just used for debug information. It can be left blank. - ServeAgent(conn net.Conn, id uuid.UUID, name string) error - // Close closes the coordinator. - Close() error - - ServeMultiAgent(id uuid.UUID) MultiAgentConn -} - // CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API. type CoordinatorV2 interface { // ServeHTTPDebug serves a debug webpage that shows the internal state of @@ -60,6 +45,7 @@ type CoordinatorV2 interface { Node(id uuid.UUID) *Node Close() error Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) + ServeMultiAgent(id uuid.UUID) MultiAgentConn } // Node represents a node in the network. @@ -389,44 +375,6 @@ func (c *inMemoryCoordination) Close() error { } } -// ServeCoordinator matches the RW structure of a coordinator to exchange node messages. 
-func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) { - errChan := make(chan error, 1) - sendErr := func(err error) { - select { - case errChan <- err: - default: - } - } - go func() { - decoder := json.NewDecoder(conn) - for { - var nodes []*Node - err := decoder.Decode(&nodes) - if err != nil { - sendErr(xerrors.Errorf("read: %w", err)) - return - } - err = updateNodes(nodes) - if err != nil { - sendErr(xerrors.Errorf("update nodes: %w", err)) - } - } - }() - - return func(node *Node) { - data, err := json.Marshal(node) - if err != nil { - sendErr(xerrors.Errorf("marshal node: %w", err)) - return - } - _, err = conn.Write(data) - if err != nil { - sendErr(xerrors.Errorf("write: %w", err)) - } - }, errChan -} - const LoggerName = "coord" var ( @@ -540,11 +488,11 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge }, }).Init() - go v1RespLoop(ctx, cancel, logger, m, resps) + go qRespLoop(ctx, cancel, logger, m, resps) return m } -// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines; +// core is an in-memory structure of peer mappings. Its methods may be called from multiple goroutines; // it is protected by a mutex to ensure data stay consistent. type core struct { logger slog.Logger @@ -607,42 +555,6 @@ func (c *core) node(id uuid.UUID) *Node { return v1Node } -// ServeClient accepts a WebSocket connection that wants to connect to an agent -// with the specified ID. -func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ServeClientV1(ctx, c.core.logger, c, conn, id, agentID) -} - -// ServeClientV1 adapts a v1 Client to a v2 Coordinator -func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - logger = logger.With(slog.F("client_id", id), slog.F("agent_id", agent)) - defer func() { - err := conn.Close() - if err != nil { - logger.Debug(ctx, "closing client connection", slog.Error(err)) - } - }() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent}) - err := SendCtx(ctx, reqs, &proto.CoordinateRequest{ - AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)}, - }) - if err != nil { - // can only be a context error, no need to log here. - return err - } - - tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient) - go tc.SendUpdates() - go v1RespLoop(ctx, cancel, logger, tc, resps) - go v1ReqLoop(ctx, cancel, logger, conn, reqs) - <-ctx.Done() - return nil -} - func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error { c.mutex.Lock() defer c.mutex.Unlock() @@ -887,34 +799,6 @@ func (c *core) removePeerLocked(id uuid.UUID, kind proto.CoordinateResponse_Peer delete(c.peers, id) } -// ServeAgent accepts a WebSocket connection to an agent that -// listens to incoming connections and publishes node updates. 
-func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ServeAgentV1(ctx, c.core.logger, c, conn, id, name) -} - -func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, name string) error { - logger = logger.With(slog.F("agent_id", id), slog.F("name", name)) - defer func() { - logger.Debug(ctx, "closing agent connection") - err := conn.Close() - logger.Debug(ctx, "closed agent connection", slog.Error(err)) - }() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - logger.Debug(ctx, "starting new agent connection") - reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id}) - tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent) - go tc.SendUpdates() - go v1RespLoop(ctx, cancel, logger, tc, resps) - go v1ReqLoop(ctx, cancel, logger, conn, reqs) - <-ctx.Done() - logger.Debug(ctx, "ending agent connection") - return nil -} - // Close closes all of the open connections in the coordinator and stops the // coordinator from accepting new connections. func (c *coordinator) Close() error { @@ -1073,44 +957,7 @@ func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) { } } -func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, - conn net.Conn, reqs chan<- *proto.CoordinateRequest, -) { - defer close(reqs) - defer cancel() - decoder := json.NewDecoder(conn) - for { - var node Node - err := decoder.Decode(&node) - if err != nil { - if xerrors.Is(err, io.EOF) || - xerrors.Is(err, io.ErrClosedPipe) || - xerrors.Is(err, context.Canceled) || - xerrors.Is(err, context.DeadlineExceeded) || - websocket.CloseStatus(err) > 0 { - logger.Debug(ctx, "v1ReqLoop exiting", slog.Error(err)) - } else { - logger.Info(ctx, "v1ReqLoop failed to decode Node update", slog.Error(err)) - } - return - } - logger.Debug(ctx, "v1ReqLoop got node update", slog.F("node", node)) - pn, err := NodeToProto(&node) - if err != nil { - logger.Critical(ctx, "v1ReqLoop failed to convert v1 node", slog.F("node", node), slog.Error(err)) - return - } - req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{ - Node: pn, - }} - if err := SendCtx(ctx, reqs, req); err != nil { - logger.Debug(ctx, "v1ReqLoop ctx expired", slog.Error(err)) - return - } - } -} - -func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) { +func qRespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) { defer func() { cErr := q.Close() if cErr != nil { @@ -1121,13 +968,13 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg for { resp, err := RecvCtx(ctx, resps) if err != nil { - logger.Debug(ctx, "v1RespLoop done reading responses", slog.Error(err)) + logger.Debug(ctx, "qRespLoop done reading responses", slog.Error(err)) return } - logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp)) + logger.Debug(ctx, "qRespLoop got response", slog.F("resp", resp)) err = q.Enqueue(resp) if err != nil && !xerrors.Is(err, context.Canceled) { - logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err)) + logger.Error(ctx, "qRespLoop failed to enqueue v1 update", slog.Error(err)) } } } diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 851d489150..400084fafa 100644 --- 
a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -2,10 +2,7 @@ package tailnet_test import ( "context" - "encoding/json" "net" - "net/http" - "net/http/httptest" "net/netip" "sync" "sync/atomic" @@ -13,10 +10,8 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "nhooyr.io/websocket" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -34,162 +29,107 @@ func TestCoordinator(t *testing.T) { t.Run("ClientWithoutAgent", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitShort) coordinator := tailnet.NewCoordinator(logger) defer func() { err := coordinator.Close() require.NoError(t, err) }() - client, server := net.Pipe() - sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { - return nil - }) - id := uuid.New() - closeChan := make(chan struct{}) - go func() { - err := coordinator.ServeClient(server, id, uuid.New()) - assert.NoError(t, err) - close(closeChan) - }() - sendNode(&tailnet.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(tailnet.IP(), 128), - }, - PreferredDERP: 10, + + client := test.NewClient(ctx, t, coordinator, "client", uuid.New()) + defer client.Close(ctx) + client.UpdateNode(&proto.Node{ + Addresses: []string{netip.PrefixFrom(tailnet.IP(), 128).String()}, + PreferredDerp: 10, }) require.Eventually(t, func() bool { - return coordinator.Node(id) != nil + return coordinator.Node(client.ID) != nil }, testutil.WaitShort, testutil.IntervalFast) - require.NoError(t, client.Close()) - require.NoError(t, server.Close()) - _ = testutil.RequireRecvCtx(ctx, t, errChan) - _ = testutil.RequireRecvCtx(ctx, t, closeChan) }) t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitShort) coordinator := tailnet.NewCoordinator(logger) defer func() { err := coordinator.Close() require.NoError(t, err) }() - client, server := net.Pipe() - sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { - return nil - }) - id := uuid.New() - closeChan := make(chan struct{}) - go func() { - err := coordinator.ServeClient(server, id, uuid.New()) - assert.NoError(t, err) - close(closeChan) - }() - sendNode(&tailnet.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(tailnet.IP(), 64), - }, - PreferredDERP: 10, - }) - _ = testutil.RequireRecvCtx(ctx, t, errChan) - _ = testutil.RequireRecvCtx(ctx, t, closeChan) + client := test.NewClient(ctx, t, coordinator, "client", uuid.New()) + defer client.Close(ctx) + + client.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(tailnet.IP(), 64).String(), + }, + PreferredDerp: 10, + }) + client.AssertEventuallyResponsesClosed() }) t.Run("AgentWithoutClients", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitShort) coordinator := tailnet.NewCoordinator(logger) defer func() { err := coordinator.Close() require.NoError(t, err) }() - client, server := net.Pipe() - sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { - return nil - }) - id := uuid.New() - closeChan := make(chan 
struct{}) - go func() { - err := coordinator.ServeAgent(server, id, "") - assert.NoError(t, err) - close(closeChan) - }() - sendNode(&tailnet.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(tailnet.IPFromUUID(id), 128), + + agent := test.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 128).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) require.Eventually(t, func() bool { - return coordinator.Node(id) != nil + return coordinator.Node(agent.ID) != nil }, testutil.WaitShort, testutil.IntervalFast) - err := client.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, errChan) - _ = testutil.RequireRecvCtx(ctx, t, closeChan) }) t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitShort) coordinator := tailnet.NewCoordinator(logger) defer func() { err := coordinator.Close() require.NoError(t, err) }() - client, server := net.Pipe() - sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { - return nil - }) - id := uuid.New() - closeChan := make(chan struct{}) - go func() { - err := coordinator.ServeAgent(server, id, "") - assert.NoError(t, err) - close(closeChan) - }() - sendNode(&tailnet.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(tailnet.IP(), 128), + agent := test.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(tailnet.IP(), 128).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) - _ = testutil.RequireRecvCtx(ctx, t, errChan) - _ = testutil.RequireRecvCtx(ctx, t, closeChan) + agent.AssertEventuallyResponsesClosed() }) t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - ctx := testutil.Context(t, testutil.WaitMedium) + ctx := testutil.Context(t, testutil.WaitShort) coordinator := tailnet.NewCoordinator(logger) defer func() { err := coordinator.Close() require.NoError(t, err) }() - client, server := net.Pipe() - sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error { - return nil - }) - id := uuid.New() - closeChan := make(chan struct{}) - go func() { - err := coordinator.ServeAgent(server, id, "") - assert.NoError(t, err) - close(closeChan) - }() - sendNode(&tailnet.Node{ - Addresses: []netip.Prefix{ - netip.PrefixFrom(tailnet.IPFromUUID(id), 64), + agent := test.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateNode(&proto.Node{ + Addresses: []string{ + netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 64).String(), }, - PreferredDERP: 10, + PreferredDerp: 10, }) - _ = testutil.RequireRecvCtx(ctx, t, errChan) - _ = testutil.RequireRecvCtx(ctx, t, closeChan) + agent.AssertEventuallyResponsesClosed() }) t.Run("AgentWithClient", func(t *testing.T) { @@ -201,180 +141,71 @@ func TestCoordinator(t *testing.T) { require.NoError(t, err) }() - // in this test we use real websockets to test use of deadlines - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() - agentWS, agentServerWS := websocketConn(ctx, t) - defer 
agentWS.Close() - agentNodeChan := make(chan []*tailnet.Node) - sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error { - agentNodeChan <- nodes - return nil - }) - agentID := uuid.New() - closeAgentChan := make(chan struct{}) - go func() { - err := coordinator.ServeAgent(agentServerWS, agentID, "") - assert.NoError(t, err) - close(closeAgentChan) - }() - sendAgentNode(&tailnet.Node{PreferredDERP: 1}) + agent := test.NewAgent(ctx, t, coordinator, "agent") + defer agent.Close(ctx) + agent.UpdateDERP(1) require.Eventually(t, func() bool { - return coordinator.Node(agentID) != nil + return coordinator.Node(agent.ID) != nil }, testutil.WaitShort, testutil.IntervalFast) - clientWS, clientServerWS := websocketConn(ctx, t) - defer clientWS.Close() - defer clientServerWS.Close() - clientNodeChan := make(chan []*tailnet.Node) - sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error { - clientNodeChan <- nodes - return nil - }) - clientID := uuid.New() - closeClientChan := make(chan struct{}) - go func() { - err := coordinator.ServeClient(clientServerWS, clientID, agentID) - assert.NoError(t, err) - close(closeClientChan) - }() - agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan) - require.Len(t, agentNodes, 1) + client := test.NewClient(ctx, t, coordinator, "client", agent.ID) + defer client.Close(ctx) + client.AssertEventuallyHasDERP(agent.ID, 1) - sendClientNode(&tailnet.Node{PreferredDERP: 2}) - clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan) - require.Len(t, clientNodes, 1) - - // wait longer than the internal wait timeout. - // this tests for regression of https://github.com/coder/coder/issues/7428 - time.Sleep(tailnet.WriteTimeout * 3 / 2) + client.UpdateDERP(2) + agent.AssertEventuallyHasDERP(client.ID, 2) // Ensure an update to the agent node reaches the client! - sendAgentNode(&tailnet.Node{PreferredDERP: 3}) - agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan) - require.Len(t, agentNodes, 1) + agent.UpdateDERP(3) + client.AssertEventuallyHasDERP(agent.ID, 3) - // Close the agent WebSocket so a new one can connect. - err := agentWS.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, agentErrChan) - _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan) + // Close the agent so a new one can connect. + agent.Close(ctx) // Create a new agent connection. This is to simulate a reconnect! - agentWS, agentServerWS = net.Pipe() - defer agentWS.Close() - agentNodeChan = make(chan []*tailnet.Node) - _, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error { - agentNodeChan <- nodes - return nil - }) - closeAgentChan = make(chan struct{}) - go func() { - err := coordinator.ServeAgent(agentServerWS, agentID, "") - assert.NoError(t, err) - close(closeAgentChan) - }() - // Ensure the existing listening client sends its node immediately! - clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan) - require.Len(t, clientNodes, 1) - - err = agentWS.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, agentErrChan) - _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan) - - err = clientWS.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, clientErrChan) - _ = testutil.RequireRecvCtx(ctx, t, closeClientChan) + agent = test.NewPeer(ctx, t, coordinator, "agent", test.WithID(agent.ID)) + defer agent.Close(ctx) + // Ensure the agent gets the existing client node immediately! 
+ agent.AssertEventuallyHasDERP(client.ID, 2) }) t.Run("AgentDoubleConnect", func(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) coordinator := tailnet.NewCoordinator(logger) - ctx := testutil.Context(t, testutil.WaitLong) + ctx := testutil.Context(t, testutil.WaitShort) - agentWS1, agentServerWS1 := net.Pipe() - defer agentWS1.Close() - agentNodeChan1 := make(chan []*tailnet.Node) - sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error { - t.Logf("agent1 got node update: %v", nodes) - agentNodeChan1 <- nodes - return nil - }) agentID := uuid.New() - closeAgentChan1 := make(chan struct{}) - go func() { - err := coordinator.ServeAgent(agentServerWS1, agentID, "") - assert.NoError(t, err) - close(closeAgentChan1) - }() - sendAgentNode1(&tailnet.Node{PreferredDERP: 1}) + agent1 := test.NewPeer(ctx, t, coordinator, "agent1", test.WithID(agentID)) + defer agent1.Close(ctx) + agent1.UpdateDERP(1) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) - clientWS, clientServerWS := net.Pipe() - defer clientWS.Close() - defer clientServerWS.Close() - clientNodeChan := make(chan []*tailnet.Node) - sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error { - t.Logf("client got node update: %v", nodes) - clientNodeChan <- nodes - return nil - }) - clientID := uuid.New() - closeClientChan := make(chan struct{}) - go func() { - err := coordinator.ServeClient(clientServerWS, clientID, agentID) - assert.NoError(t, err) - close(closeClientChan) - }() - agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan) - require.Len(t, agentNodes, 1) - sendClientNode(&tailnet.Node{PreferredDERP: 2}) - clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1) - require.Len(t, clientNodes, 1) + client := test.NewPeer(ctx, t, coordinator, "client") + defer client.Close(ctx) + client.AddTunnel(agentID) + client.AssertEventuallyHasDERP(agent1.ID, 1) + + client.UpdateDERP(2) + agent1.AssertEventuallyHasDERP(client.ID, 2) // Ensure an update to the agent node reaches the client! - sendAgentNode1(&tailnet.Node{PreferredDERP: 3}) - agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan) - require.Len(t, agentNodes, 1) + agent1.UpdateDERP(3) + client.AssertEventuallyHasDERP(agent1.ID, 3) // Create a new agent connection without disconnecting the old one. - agentWS2, agentServerWS2 := net.Pipe() - defer agentWS2.Close() - agentNodeChan2 := make(chan []*tailnet.Node) - _, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error { - t.Logf("agent2 got node update: %v", nodes) - agentNodeChan2 <- nodes - return nil - }) - closeAgentChan2 := make(chan struct{}) - go func() { - err := coordinator.ServeAgent(agentServerWS2, agentID, "") - assert.NoError(t, err) - close(closeAgentChan2) - }() + agent2 := test.NewPeer(ctx, t, coordinator, "agent2", test.WithID(agentID)) + defer agent2.Close(ctx) - // Ensure the existing listening client sends it's node immediately! - clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2) - require.Len(t, clientNodes, 1) + // Ensure the existing client node gets sent immediately! + agent2.AssertEventuallyHasDERP(client.ID, 2) - // This original agent websocket should've been closed forcefully. 
- _ = testutil.RequireRecvCtx(ctx, t, agentErrChan1) - _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1) - - err := agentWS2.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, agentErrChan2) - _ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2) - - err = clientWS.Close() - require.NoError(t, err) - _ = testutil.RequireRecvCtx(ctx, t, clientErrChan) - _ = testutil.RequireRecvCtx(ctx, t, closeClientChan) + // This original agent channels should've been closed forcefully. + agent1.AssertEventuallyResponsesClosed() }) t.Run("AgentAck", func(t *testing.T) { @@ -396,89 +227,6 @@ func TestCoordinator(t *testing.T) { }) } -// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on -// https://github.com/coder/coder/issues/7295 -func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) { - t.Parallel() - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - coordinator := tailnet.NewCoordinator(logger) - agentWS, agentServerWS := net.Pipe() - defer agentWS.Close() - - agentID := uuid.New() - go func() { - err := coordinator.ServeAgent(agentServerWS, agentID, "") - assert.NoError(t, err) - }() - - // send an agent update before the client connects so that there is - // node data available to send right away. - aNode := tailnet.Node{PreferredDERP: 0} - aData, err := json.Marshal(&aNode) - require.NoError(t, err) - err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort)) - require.NoError(t, err) - _, err = agentWS.Write(aData) - require.NoError(t, err) - - require.Eventually(t, func() bool { - return coordinator.Node(agentID) != nil - }, testutil.WaitShort, testutil.IntervalFast) - - // Connect from the client - clientWS, clientServerWS := net.Pipe() - defer clientWS.Close() - clientID := uuid.New() - go func() { - err := coordinator.ServeClient(clientServerWS, clientID, agentID) - assert.NoError(t, err) - }() - - // peek one byte from the node update, so we know the coordinator is - // trying to write to the client. - // buffer needs to be 2 characters longer because return value is a list - // so, it needs [ and ] - buf := make([]byte, len(aData)+2) - err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) - require.NoError(t, err) - n, err := clientWS.Read(buf[:1]) - require.NoError(t, err) - require.Equal(t, 1, n) - - // send a second update - aNode.PreferredDERP = 1 - require.NoError(t, err) - aData, err = json.Marshal(&aNode) - require.NoError(t, err) - err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort)) - require.NoError(t, err) - _, err = agentWS.Write(aData) - require.NoError(t, err) - - // read the rest of the update from the client, should be initial node. - err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) - require.NoError(t, err) - n, err = clientWS.Read(buf[1:]) - require.NoError(t, err) - var cNodes []*tailnet.Node - err = json.Unmarshal(buf[:n+1], &cNodes) - require.NoError(t, err) - require.Len(t, cNodes, 1) - require.Equal(t, 0, cNodes[0].PreferredDERP) - - // read second update - // without a fix for https://github.com/coder/coder/issues/7295 our - // read would time out here. 
- err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort)) - require.NoError(t, err) - n, err = clientWS.Read(buf) - require.NoError(t, err) - err = json.Unmarshal(buf[:n], &cNodes) - require.NoError(t, err) - require.Len(t, cNodes, 1) - require.Equal(t, 1, cNodes[0].PreferredDERP) -} - func TestCoordinator_BidirectionalTunnels(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) @@ -521,29 +269,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) { ma1.RequireEventuallyClosed(ctx) } -func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) { - t.Helper() - sc := make(chan net.Conn, 1) - s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - wss, err := websocket.Accept(rw, r, nil) - require.NoError(t, err) - conn := websocket.NetConn(r.Context(), wss, websocket.MessageBinary) - sc <- conn - close(sc) // there can be only one - - // hold open until context canceled - <-ctx.Done() - })) - t.Cleanup(s.Close) - // nolint: bodyclose - wsc, _, err := websocket.Dial(ctx, s.URL, nil) - require.NoError(t, err) - client = websocket.NetConn(ctx, wsc, websocket.MessageBinary) - server, ok := <-sc - require.True(t, ok) - return client, server -} - func TestInMemoryCoordination(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) diff --git a/tailnet/proto/version.go b/tailnet/proto/version.go index e069b2d2f9..4eaf60f2a9 100644 --- a/tailnet/proto/version.go +++ b/tailnet/proto/version.go @@ -9,4 +9,4 @@ const ( CurrentMinor = 2 ) -var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1) +var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor) diff --git a/tailnet/service.go b/tailnet/service.go index 22111ce2fe..28a054dd8d 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -21,6 +21,8 @@ import ( "github.com/coder/quartz" ) +var ErrUnsupportedVersion = xerrors.New("unsupported version") + type streamIDContextKey struct{} // StreamID identifies the caller of the CoordinateTailnet RPC. We store this @@ -47,7 +49,7 @@ type ClientServiceOptions struct { } // ClientService is a tailnet coordination service that accepts a connection and version from a -// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol. +// tailnet client, and support versions 2.x of the Tailnet API protocol. 
type ClientService struct { Logger slog.Logger CoordPtr *atomic.Pointer[Coordinator] @@ -94,9 +96,6 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne return err } switch major { - case 1: - coord := *(s.CoordPtr.Load()) - return coord.ServeClient(conn, id, agent) case 2: auth := ClientCoordinateeAuth{AgentID: agent} streamID := StreamID{ @@ -107,7 +106,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne return s.ServeConnV2(ctx, conn, streamID) default: s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) - return xerrors.New("unsupported version") + return ErrUnsupportedVersion } } diff --git a/tailnet/service_test.go b/tailnet/service_test.go index 71a7fdacd8..0f4b4795c4 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -162,21 +162,8 @@ func TestClientService_ServeClient_V1(t *testing.T) { errCh <- err }() - call := testutil.RequireRecvCtx(ctx, t, fCoord.ServeClientCalls) - require.NotNil(t, call) - require.Equal(t, call.ID, clientID) - require.Equal(t, call.Agent, agentID) - require.Equal(t, s, call.Conn) - expectedError := xerrors.New("test error") - select { - case call.ErrCh <- expectedError: - // ok! - case <-ctx.Done(): - t.Fatalf("timeout sending error") - } - err = testutil.RequireRecvCtx(ctx, t, errCh) - require.ErrorIs(t, err, expectedError) + require.ErrorIs(t, err, tailnet.ErrUnsupportedVersion) } func TestNetworkTelemetryBatcher(t *testing.T) { diff --git a/tailnet/tailnettest/coordinatormock.go b/tailnet/tailnettest/coordinatormock.go index 6225b8c86a..e408a4b8ec 100644 --- a/tailnet/tailnettest/coordinatormock.go +++ b/tailnet/tailnettest/coordinatormock.go @@ -11,7 +11,6 @@ package tailnettest import ( context "context" - net "net" http "net/http" reflect "reflect" @@ -87,34 +86,6 @@ func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0) } -// ServeAgent mocks base method. -func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// ServeAgent indicates an expected call of ServeAgent. -func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2) -} - -// ServeClient mocks base method. -func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 -} - -// ServeClient indicates an expected call of ServeClient. -func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2) -} - // ServeHTTPDebug mocks base method. 
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) { m.ctrl.T.Helper() diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 9b34c9bd3d..3dd2430ca2 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -161,7 +161,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap { type TestMultiAgent struct { t testing.TB - id uuid.UUID + ID uuid.UUID a tailnet.MultiAgentConn nodeKey []byte discoKey string @@ -172,8 +172,8 @@ func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent require.NoError(t, err) dk, err := key.NewDisco().Public().MarshalText() require.NoError(t, err) - m := &TestMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)} - m.a = coord.ServeMultiAgent(m.id) + m := &TestMultiAgent{t: t, ID: uuid.New(), nodeKey: nk, discoKey: string(dk)} + m.a = coord.ServeMultiAgent(m.ID) return m } @@ -277,8 +277,7 @@ func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) { } type FakeCoordinator struct { - CoordinateCalls chan *FakeCoordinate - ServeClientCalls chan *FakeServeClient + CoordinateCalls chan *FakeCoordinate } func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) { @@ -289,21 +288,6 @@ func (*FakeCoordinator) Node(uuid.UUID) *tailnet.Node { panic("unimplemented") } -func (f *FakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { - errCh := make(chan error) - f.ServeClientCalls <- &FakeServeClient{ - Conn: conn, - ID: id, - Agent: agent, - ErrCh: errCh, - } - return <-errCh -} - -func (*FakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error { - panic("unimplemented") -} - func (*FakeCoordinator) Close() error { panic("unimplemented") } @@ -328,8 +312,7 @@ func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name str func NewFakeCoordinator() *FakeCoordinator { return &FakeCoordinator{ - CoordinateCalls: make(chan *FakeCoordinate, 100), - ServeClientCalls: make(chan *FakeServeClient, 100), + CoordinateCalls: make(chan *FakeCoordinate, 100), } } diff --git a/tailnet/test/peer.go b/tailnet/test/peer.go index 1b08d6886a..ce9a507499 100644 --- a/tailnet/test/peer.go +++ b/tailnet/test/peer.go @@ -3,10 +3,12 @@ package test import ( "context" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "golang.org/x/xerrors" + "tailscale.com/types/key" "github.com/coder/coder/v2/tailnet" "github.com/coder/coder/v2/tailnet/proto" @@ -18,43 +20,73 @@ type PeerStatus struct { readyForHandshake bool } +type PeerOption func(*Peer) + +func WithID(id uuid.UUID) PeerOption { + return func(p *Peer) { + p.ID = id + } +} + +func WithAuth(auth tailnet.CoordinateeAuth) PeerOption { + return func(p *Peer) { + p.auth = auth + } +} + type Peer struct { ctx context.Context cancel context.CancelFunc t testing.TB ID uuid.UUID + auth tailnet.CoordinateeAuth name string + nodeKey key.NodePublic + discoKey key.DiscoPublic resps <-chan *proto.CoordinateResponse reqs chan<- *proto.CoordinateRequest peers map[uuid.UUID]PeerStatus peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate } -func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer { +func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, opts ...PeerOption) *Peer { p := &Peer{ t: t, name: name, peers: make(map[uuid.UUID]PeerStatus), peerUpdates: 
make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate), + ID: uuid.New(), + // SingleTailnetCoordinateeAuth allows connections to arbitrary peers + auth: tailnet.SingleTailnetCoordinateeAuth{}, + // required for converting to and from protobuf, so we always include them + nodeKey: key.NewNode().Public(), + discoKey: key.NewDisco().Public(), } p.ctx, p.cancel = context.WithCancel(ctx) - if len(id) > 1 { - t.Fatal("too many") + for _, opt := range opts { + opt(p) } - if len(id) == 1 { - p.ID = id[0] - } else { - p.ID = uuid.New() - } - // SingleTailnetTunnelAuth allows connections to arbitrary peers - p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{}) + + p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, p.auth) + return p +} + +// NewAgent is a wrapper around NewPeer, creating a peer with Agent auth tied to its ID +func NewAgent(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string) *Peer { + id := uuid.New() + return NewPeer(ctx, t, coord, name, WithID(id), WithAuth(tailnet.AgentCoordinateeAuth{ID: id})) +} + +// NewClient is a wrapper around NewPeer, creating a peer with Client auth tied to the provided agentID +func NewClient(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, agentID uuid.UUID) *Peer { + p := NewPeer(ctx, t, coord, name, WithAuth(tailnet.ClientCoordinateeAuth{AgentID: agentID})) + p.AddTunnel(agentID) return p } func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) { p.t.Helper() - - p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, tailnet.SingleTailnetCoordinateeAuth{}) + p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, p.auth) } func (p *Peer) AddTunnel(other uuid.UUID) { @@ -71,7 +103,19 @@ func (p *Peer) AddTunnel(other uuid.UUID) { func (p *Peer) UpdateDERP(derp int32) { p.t.Helper() - req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}} + node := &proto.Node{PreferredDerp: derp} + p.UpdateNode(node) +} + +func (p *Peer) UpdateNode(node *proto.Node) { + p.t.Helper() + nk, err := p.nodeKey.MarshalBinary() + assert.NoError(p.t, err) + node.Key = nk + dk, err := p.discoKey.MarshalText() + assert.NoError(p.t, err) + node.Disco = string(dk) + req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}} select { case <-p.ctx.Done(): p.t.Errorf("timeout updating node for %s", p.name) @@ -115,13 +159,37 @@ func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) { if ok && o.preferredDERP == derp { return } - if err := p.handleOneResp(); err != nil { + if err := p.readOneResp(); err != nil { assert.NoError(p.t, err) return } } } +func (p *Peer) AssertNeverHasDERPs(ctx context.Context, other uuid.UUID, expected ...int32) { + p.t.Helper() + for { + select { + case <-ctx.Done(): + return + case resp, ok := <-p.resps: + if !ok { + p.t.Errorf("response channel closed") + } + if !assert.NoError(p.t, p.handleResp(resp)) { + return + } + derp, ok := p.peers[other] + if !ok { + continue + } + if !assert.NotContains(p.t, expected, derp) { + return + } + } + } +} + func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) { p.t.Helper() for { @@ -129,7 +197,7 @@ func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) { if !ok { return } - if err := p.handleOneResp(); err != nil { + if err := p.readOneResp(); err != nil { assert.NoError(p.t, err) return } @@ -143,7 +211,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) { if o.status 
== proto.CoordinateResponse_PeerUpdate_LOST { return } - if err := p.handleOneResp(); err != nil { + if err := p.readOneResp(); err != nil { assert.NoError(p.t, err) return } @@ -153,7 +221,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) { func (p *Peer) AssertEventuallyResponsesClosed() { p.t.Helper() for { - err := p.handleOneResp() + err := p.readOneResp() if xerrors.Is(err, responsesClosed) { return } @@ -163,6 +231,32 @@ func (p *Peer) AssertEventuallyResponsesClosed() { } } +func (p *Peer) AssertNotClosed(d time.Duration) { + p.t.Helper() + // nolint: gocritic // erroneously thinks we're hardcoding non testutil constants here + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + for { + select { + case <-ctx.Done(): + // success! + return + case <-p.ctx.Done(): + p.t.Error("main ctx timeout before elapsed time") + return + case resp, ok := <-p.resps: + if !ok { + p.t.Error("response channel closed") + return + } + err := p.handleResp(resp) + if !assert.NoError(p.t, err) { + return + } + } + } +} + func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) { p.t.Helper() for { @@ -171,7 +265,7 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) { return } - err := p.handleOneResp() + err := p.readOneResp() if xerrors.Is(err, responsesClosed) { return } @@ -181,8 +275,9 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) { func (p *Peer) AssertEventuallyGetsError(match string) { p.t.Helper() for { - err := p.handleOneResp() + err := p.readOneResp() if xerrors.Is(err, responsesClosed) { + p.t.Error("closed before target error") return } @@ -207,7 +302,7 @@ func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateRespon var responsesClosed = xerrors.New("responses closed") -func (p *Peer) handleOneResp() error { +func (p *Peer) readOneResp() error { select { case <-p.ctx.Done(): return p.ctx.Err() @@ -215,31 +310,39 @@ func (p *Peer) handleOneResp() error { if !ok { return responsesClosed } - if resp.Error != "" { - return xerrors.New(resp.Error) + err := p.handleResp(resp) + if err != nil { + return err } - for _, update := range resp.PeerUpdates { - id, err := uuid.FromBytes(update.Id) - if err != nil { - return err - } - p.peerUpdates[id] = append(p.peerUpdates[id], update) + } + return nil +} - switch update.Kind { - case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST: - peer := p.peers[id] - peer.preferredDERP = update.GetNode().GetPreferredDerp() - peer.status = update.Kind - p.peers[id] = peer - case proto.CoordinateResponse_PeerUpdate_DISCONNECTED: - delete(p.peers, id) - case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE: - peer := p.peers[id] - peer.readyForHandshake = true - p.peers[id] = peer - default: - return xerrors.Errorf("unhandled update kind %s", update.Kind) - } +func (p *Peer) handleResp(resp *proto.CoordinateResponse) error { + if resp.Error != "" { + return xerrors.New(resp.Error) + } + for _, update := range resp.PeerUpdates { + id, err := uuid.FromBytes(update.Id) + if err != nil { + return err + } + p.peerUpdates[id] = append(p.peerUpdates[id], update) + + switch update.Kind { + case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST: + peer := p.peers[id] + peer.preferredDERP = update.GetNode().GetPreferredDerp() + peer.status = update.Kind + p.peers[id] = peer + case proto.CoordinateResponse_PeerUpdate_DISCONNECTED: + delete(p.peers, id) + case 
proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE: + peer := p.peers[id] + peer.readyForHandshake = true + p.peers[id] = peer + default: + return xerrors.Errorf("unhandled update kind %s", update.Kind) } } return nil @@ -261,3 +364,9 @@ func (p *Peer) Close(ctx context.Context) { } } } + +func (p *Peer) UngracefulDisconnect(ctx context.Context) { + p.t.Helper() + close(p.reqs) + p.Close(ctx) +} diff --git a/tailnet/trackedconn.go b/tailnet/trackedconn.go deleted file mode 100644 index a801cdfae0..0000000000 --- a/tailnet/trackedconn.go +++ /dev/null @@ -1,182 +0,0 @@ -package tailnet - -import ( - "bytes" - "context" - "encoding/json" - "net" - "sync/atomic" - "time" - - "github.com/google/uuid" - - "cdr.dev/slog" - "github.com/coder/coder/v2/tailnet/proto" -) - -const ( - // WriteTimeout is the amount of time we wait to write a node update to a connection before we - // declare it hung. It is exported so that tests can use it. - WriteTimeout = time.Second * 5 - // ResponseBufferSize is the max number of responses to buffer per connection before we start - // dropping updates - ResponseBufferSize = 512 - // RequestBufferSize is the max number of requests to buffer per connection - RequestBufferSize = 32 -) - -type TrackedConn struct { - ctx context.Context - cancel func() - kind QueueKind - conn net.Conn - updates chan *proto.CoordinateResponse - logger slog.Logger - lastData []byte - - // ID is an ephemeral UUID used to uniquely identify the owner of the - // connection. - id uuid.UUID - - name string - start int64 - lastWrite int64 - overwrites int64 -} - -func NewTrackedConn(ctx context.Context, cancel func(), - conn net.Conn, - id uuid.UUID, - logger slog.Logger, - name string, - overwrites int64, - kind QueueKind, -) *TrackedConn { - // buffer updates so they don't block, since we hold the - // coordinator mutex while queuing. Node updates don't - // come quickly, so 512 should be plenty for all but - // the most pathological cases. - updates := make(chan *proto.CoordinateResponse, ResponseBufferSize) - now := time.Now().Unix() - return &TrackedConn{ - ctx: ctx, - conn: conn, - cancel: cancel, - updates: updates, - logger: logger, - id: id, - start: now, - lastWrite: now, - name: name, - overwrites: overwrites, - kind: kind, - } -} - -func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) { - atomic.StoreInt64(&t.lastWrite, time.Now().Unix()) - select { - case t.updates <- resp: - return nil - default: - return ErrWouldBlock - } -} - -func (t *TrackedConn) UniqueID() uuid.UUID { - return t.id -} - -func (t *TrackedConn) Kind() QueueKind { - return t.kind -} - -func (t *TrackedConn) Name() string { - return t.name -} - -func (t *TrackedConn) Stats() (start, lastWrite int64) { - return t.start, atomic.LoadInt64(&t.lastWrite) -} - -func (t *TrackedConn) Overwrites() int64 { - return t.overwrites -} - -func (t *TrackedConn) CoordinatorClose() error { - return t.Close() -} - -func (t *TrackedConn) Done() <-chan struct{} { - return t.ctx.Done() -} - -// Close the connection and cancel the context for reading node updates from the queue -func (t *TrackedConn) Close() error { - t.cancel() - return t.conn.Close() -} - -// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is -// canceled. 
-func (t *TrackedConn) SendUpdates() { - for { - select { - case <-t.ctx.Done(): - t.logger.Debug(t.ctx, "done sending updates") - return - case resp := <-t.updates: - nodes, err := OnlyNodeUpdates(resp) - if err != nil { - t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err)) - return - } - if len(nodes) == 0 { - t.logger.Debug(t.ctx, "skipping response with no nodes") - continue - } - data, err := json.Marshal(nodes) - if err != nil { - t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) - return - } - if bytes.Equal(t.lastData, data) { - t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data))) - continue - } - - // Set a deadline so that hung connections don't put back pressure on the system. - // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. - err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err)) - _ = t.Close() - return - } - _, err = t.conn.Write(data) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "could not write nodes to connection", - slog.Error(err), slog.F("nodes", string(data))) - _ = t.Close() - return - } - t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data))) - - // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are - // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() - // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. - // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after - // our successful write, it is important that we reset the deadline before it fires. - err = t.conn.SetWriteDeadline(time.Time{}) - if err != nil { - // often, this is just because the connection is closed/broken, so only log at debug. - t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err)) - _ = t.Close() - return - } - t.lastData = data - } - } -}
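
Reviewer note (not part of the patch): the refactor above replaces the hand-rolled net.Pipe / ServeCoordinator plumbing with the tailnet/test Peer helpers. Below is a minimal sketch of a coordinator test written against that API. The test name and exact flow are illustrative assumptions, but every helper it calls (test.NewAgent, test.NewClient, UpdateDERP, AssertEventuallyHasDERP, UngracefulDisconnect, AssertEventuallyLost) is introduced or exercised in the changes above.

package tailnet_test

import (
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

// TestCoordinator_PeerHelpers_Sketch is illustrative only; it is not part of
// the diff above.
func TestCoordinator_PeerHelpers_Sketch(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitShort)

	coordinator := tailnet.NewCoordinator(logger)
	defer func() {
		require.NoError(t, coordinator.Close())
	}()

	// NewAgent builds a Peer with AgentCoordinateeAuth bound to its own ID.
	agent := test.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(1)

	// NewClient builds a Peer with ClientCoordinateeAuth for the agent and
	// opens the tunnel to it.
	client := test.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)

	// Node updates flow both ways; the assertions read responses until the
	// expected DERP shows up or the test context expires.
	client.AssertEventuallyHasDERP(agent.ID, 1)
	client.UpdateDERP(2)
	agent.AssertEventuallyHasDERP(client.ID, 2)

	// Drop the agent without a graceful disconnect. This mirrors the
	// expectation in the PG coordinator tests above: an abrupt drop should
	// eventually surface to tunneled peers as LOST.
	agent.UngracefulDisconnect(ctx)
	client.AssertEventuallyLost(agent.ID)
}

The sketch also shows why NewPeer moved from the old variadic id parameter to functional options (WithID, WithAuth): NewAgent and NewClient can bind the appropriate CoordinateeAuth without widening the constructor signature, and tests that need a fixed ID (e.g. the AgentDoubleConnect reconnect case) pass test.WithID explicitly.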