Mirror of https://github.com/coder/coder.git, synced 2025-07-13 21:36:50 +00:00
chore: remove tailnet v1 API support (#14641)
Drops support for v1 of the tailnet API, the original coordination protocol in which we only sent node updates and never marked peers as lost or disconnected. v2 of the tailnet API went GA for CLI clients in Coder 2.8.0, so CLI clients older than that will stop working.
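For context, every test in this diff moves from the removed v1-era helpers (`newTestAgent`/`newTestClient` with `sendNode` and raw node slices) to the v2 `agpltest` peer helpers. The following is a minimal sketch of that flow, assuming the helper signatures visible in the diff below (`NewAgent`, `NewClient`, `UpdateDERP`, `AssertEventuallyHasDERP`, `Close`, `UngracefulDisconnect`); the coordinator constructor and context helper are placeholders, not part of this change:

```go
// Sketch only: drives a coordinator through v2-style test peers.
func TestCoordinateV2Sketch(t *testing.T) {
	ctx := testutil.Context(t, testutil.WaitShort) // assumed test-context helper
	coordinator := newTestCoordinator(t)           // hypothetical; any tailnet coordinator under test

	// An agent peer publishes its preferred DERP region.
	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(10)

	// A client peer tunnels to that agent by ID and observes the update.
	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 10)

	// Updates flow both ways, keyed by peer ID rather than by the raw
	// node broadcasts of the removed v1 protocol.
	client.UpdateDERP(11)
	agent.AssertEventuallyHasDERP(client.ID, 11)

	// Dropping the connection without a graceful disconnect leaves the
	// peer marked lost, which v1 had no way to express.
	agent.UngracefulDisconnect(ctx)
	client.UngracefulDisconnect(ctx)
}
```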
@@ -10,8 +10,8 @@ import (
 	"cdr.dev/slog/sloggers/slogtest"
 	"github.com/coder/coder/v2/coderd/database/dbtestutil"
 	"github.com/coder/coder/v2/enterprise/tailnet"
-	agpl "github.com/coder/coder/v2/tailnet"
 	"github.com/coder/coder/v2/tailnet/tailnettest"
+	agpltest "github.com/coder/coder/v2/tailnet/test"
 	"github.com/coder/coder/v2/testutil"
 )
 
@@ -35,27 +35,27 @@ func TestPGCoordinator_MultiAgent(t *testing.T) {
 	require.NoError(t, err)
 	defer coord1.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
 	defer ma1.Close()
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
 	ma1.SendNodeWithDERP(3)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
 
 	ma1.Close()
-	require.NoError(t, agent1.close())
+	agent1.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 func TestPGCoordinator_MultiAgent_CoordClose(t *testing.T) {
@@ -102,28 +102,28 @@ func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
 	require.NoError(t, err)
 	defer coord1.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
 	defer ma1.Close()
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
 	ma1.SendNodeWithDERP(3)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
 
-	ma1.RequireUnsubscribeAgent(agent1.id)
+	ma1.RequireUnsubscribeAgent(agent1.ID)
 	ma1.Close()
-	require.NoError(t, agent1.close())
+	agent1.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 // TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
@@ -147,43 +147,43 @@ func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
 	require.NoError(t, err)
 	defer coord1.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord1)
 	defer ma1.Close()
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
 	ma1.SendNodeWithDERP(3)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
 
-	ma1.RequireUnsubscribeAgent(agent1.id)
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
+	ma1.RequireUnsubscribeAgent(agent1.ID)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
 
 	func() {
 		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
 		defer cancel()
 		ma1.SendNodeWithDERP(9)
-		assertNeverHasDERPs(ctx, t, agent1, 9)
+		agent1.AssertNeverHasDERPs(ctx, ma1.ID, 9)
 	}()
 	func() {
 		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
 		defer cancel()
-		agent1.sendNode(&agpl.Node{PreferredDERP: 8})
+		agent1.UpdateDERP(8)
 		ma1.RequireNeverHasDERPs(ctx, 8)
 	}()
 
 	ma1.Close()
-	require.NoError(t, agent1.close())
+	agent1.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 // TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
@@ -212,27 +212,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
 	require.NoError(t, err)
 	defer coord2.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord2)
 	defer ma1.Close()
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
 	ma1.SendNodeWithDERP(3)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
 
 	ma1.Close()
-	require.NoError(t, agent1.close())
+	agent1.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 // TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
@@ -262,27 +262,27 @@ func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *test
 	require.NoError(t, err)
 	defer coord2.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord2)
 	defer ma1.Close()
 
 	ma1.SendNodeWithDERP(3)
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
 	ma1.Close()
-	require.NoError(t, agent1.close())
+	agent1.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 // TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
@@ -317,37 +317,37 @@ func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
 	require.NoError(t, err)
 	defer coord3.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent1.UpdateDERP(5)
 
-	agent2 := newTestAgent(t, coord2, "agent2")
-	defer agent1.close()
-	agent2.sendNode(&agpl.Node{PreferredDERP: 6})
+	agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
+	defer agent2.Close(ctx)
+	agent2.UpdateDERP(6)
 
 	ma1 := tailnettest.NewTestMultiAgent(t, coord3)
 	defer ma1.Close()
 
-	ma1.RequireSubscribeAgent(agent1.id)
+	ma1.RequireSubscribeAgent(agent1.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 5)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
+	agent1.UpdateDERP(1)
 	ma1.RequireEventuallyHasDERPs(ctx, 1)
 
-	ma1.RequireSubscribeAgent(agent2.id)
+	ma1.RequireSubscribeAgent(agent2.ID)
 	ma1.RequireEventuallyHasDERPs(ctx, 6)
 
-	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
+	agent2.UpdateDERP(2)
 	ma1.RequireEventuallyHasDERPs(ctx, 2)
 
 	ma1.SendNodeWithDERP(3)
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
-	assertEventuallyHasDERPs(ctx, t, agent2, 3)
+	agent1.AssertEventuallyHasDERP(ma1.ID, 3)
+	agent2.AssertEventuallyHasDERP(ma1.ID, 3)
 
 	ma1.Close()
-	require.NoError(t, agent1.close())
-	require.NoError(t, agent2.close())
+	agent1.UngracefulDisconnect(ctx)
+	agent2.UngracefulDisconnect(ctx)
 
-	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
@@ -3,7 +3,6 @@ package tailnet
 import (
 	"context"
 	"database/sql"
-	"net"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -213,14 +212,6 @@ func (c *pgCoord) Node(id uuid.UUID) *agpl.Node {
 	return node
 }
 
-func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
-	return agpl.ServeClientV1(c.ctx, c.logger, c, conn, id, agent)
-}
-
-func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
-	return agpl.ServeAgentV1(c.ctx, c.logger, c, conn, id, name)
-}
-
 func (c *pgCoord) Close() error {
 	c.logger.Info(c.ctx, "closing coordinator")
 	c.cancel()
@@ -3,8 +3,6 @@ package tailnet_test
 import (
 	"context"
 	"database/sql"
-	"io"
-	"net"
 	"net/netip"
 	"sync"
 	"testing"
@@ -15,7 +13,6 @@ import (
 	"github.com/stretchr/testify/require"
 	"go.uber.org/goleak"
 	"go.uber.org/mock/gomock"
-	"golang.org/x/exp/slices"
 	"golang.org/x/xerrors"
 	gProto "google.golang.org/protobuf/proto"
 
@@ -51,9 +48,9 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
 	defer coordinator.Close()
 
 	agentID := uuid.New()
-	client := newTestClient(t, coordinator, agentID)
-	defer client.close()
-	client.sendNode(&agpl.Node{PreferredDERP: 10})
+	client := agpltest.NewClient(ctx, t, coordinator, "client", agentID)
+	defer client.Close(ctx)
+	client.UpdateDERP(10)
 	require.Eventually(t, func() bool {
 		clients, err := store.GetTailnetTunnelPeerBindings(ctx, agentID)
 		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
@@ -68,12 +65,8 @@ func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
 		assert.EqualValues(t, 10, node.PreferredDerp)
 		return true
 	}, testutil.WaitShort, testutil.IntervalFast)
-
-	err = client.close()
-	require.NoError(t, err)
-	<-client.errChan
-	<-client.closeChan
-	assertEventuallyLost(ctx, t, store, client.id)
+	client.UngracefulDisconnect(ctx)
+	assertEventuallyLost(ctx, t, store, client.ID)
 }
 
 func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
@@ -89,11 +82,11 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "agent")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{PreferredDERP: 10})
+	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
+	defer agent.Close(ctx)
+	agent.UpdateDERP(10)
 	require.Eventually(t, func() bool {
-		agents, err := store.GetTailnetPeers(ctx, agent.id)
+		agents, err := store.GetTailnetPeers(ctx, agent.ID)
 		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
 			t.Fatalf("database error: %v", err)
 		}
@@ -106,11 +99,8 @@ func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
 		assert.EqualValues(t, 10, node.PreferredDerp)
 		return true
 	}, testutil.WaitShort, testutil.IntervalFast)
-	err = agent.close()
-	require.NoError(t, err)
-	<-agent.errChan
-	<-agent.closeChan
-	assertEventuallyLost(ctx, t, store, agent.id)
+	agent.UngracefulDisconnect(ctx)
+	assertEventuallyLost(ctx, t, store, agent.ID)
 }
 
 func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
@@ -126,18 +116,18 @@ func TestPGCoordinatorSingle_AgentInvalidIP(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "agent")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{
-		Addresses: []netip.Prefix{
-			netip.PrefixFrom(agpl.IP(), 128),
+	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
+	defer agent.Close(ctx)
+	agent.UpdateNode(&proto.Node{
+		Addresses: []string{
+			netip.PrefixFrom(agpl.IP(), 128).String(),
 		},
-		PreferredDERP: 10,
+		PreferredDerp: 10,
 	})
 
 	// The agent connection should be closed immediately after sending an invalid addr
-	testutil.RequireRecvCtx(ctx, t, agent.closeChan)
-	assertEventuallyLost(ctx, t, store, agent.id)
+	agent.AssertEventuallyResponsesClosed()
+	assertEventuallyLost(ctx, t, store, agent.ID)
 }
 
 func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
@@ -153,18 +143,18 @@ func TestPGCoordinatorSingle_AgentInvalidIPBits(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "agent")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{
-		Addresses: []netip.Prefix{
-			netip.PrefixFrom(agpl.IPFromUUID(agent.id), 64),
+	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
+	defer agent.Close(ctx)
+	agent.UpdateNode(&proto.Node{
+		Addresses: []string{
+			netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 64).String(),
 		},
-		PreferredDERP: 10,
+		PreferredDerp: 10,
 	})
 
 	// The agent connection should be closed immediately after sending an invalid addr
-	testutil.RequireRecvCtx(ctx, t, agent.closeChan)
-	assertEventuallyLost(ctx, t, store, agent.id)
+	agent.AssertEventuallyResponsesClosed()
+	assertEventuallyLost(ctx, t, store, agent.ID)
 }
 
 func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
@@ -180,16 +170,16 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "agent")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{
-		Addresses: []netip.Prefix{
-			netip.PrefixFrom(agpl.IPFromUUID(agent.id), 128),
+	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
+	defer agent.Close(ctx)
+	agent.UpdateNode(&proto.Node{
+		Addresses: []string{
+			netip.PrefixFrom(agpl.IPFromUUID(agent.ID), 128).String(),
 		},
-		PreferredDERP: 10,
+		PreferredDerp: 10,
 	})
 	require.Eventually(t, func() bool {
-		agents, err := store.GetTailnetPeers(ctx, agent.id)
+		agents, err := store.GetTailnetPeers(ctx, agent.ID)
 		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
 			t.Fatalf("database error: %v", err)
 		}
@@ -202,11 +192,8 @@ func TestPGCoordinatorSingle_AgentValidIP(t *testing.T) {
 		assert.EqualValues(t, 10, node.PreferredDerp)
 		return true
 	}, testutil.WaitShort, testutil.IntervalFast)
-	err = agent.close()
-	require.NoError(t, err)
-	<-agent.errChan
-	<-agent.closeChan
-	assertEventuallyLost(ctx, t, store, agent.id)
+	agent.UngracefulDisconnect(ctx)
+	assertEventuallyLost(ctx, t, store, agent.ID)
 }
 
 func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
@@ -222,68 +209,40 @@ func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "original")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{PreferredDERP: 10})
+	agent := agpltest.NewAgent(ctx, t, coordinator, "original")
+	defer agent.Close(ctx)
+	agent.UpdateDERP(10)
 
-	client := newTestClient(t, coordinator, agent.id)
-	defer client.close()
+	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
+	defer client.Close(ctx)
 
-	agentNodes := client.recvNodes(ctx, t)
-	require.Len(t, agentNodes, 1)
-	assert.Equal(t, 10, agentNodes[0].PreferredDERP)
-	client.sendNode(&agpl.Node{PreferredDERP: 11})
-	clientNodes := agent.recvNodes(ctx, t)
-	require.Len(t, clientNodes, 1)
-	assert.Equal(t, 11, clientNodes[0].PreferredDERP)
+	client.AssertEventuallyHasDERP(agent.ID, 10)
+	client.UpdateDERP(11)
+	agent.AssertEventuallyHasDERP(client.ID, 11)
 
 	// Ensure an update to the agent node reaches the connIO!
-	agent.sendNode(&agpl.Node{PreferredDERP: 12})
-	agentNodes = client.recvNodes(ctx, t)
-	require.Len(t, agentNodes, 1)
-	assert.Equal(t, 12, agentNodes[0].PreferredDERP)
+	agent.UpdateDERP(12)
+	client.AssertEventuallyHasDERP(agent.ID, 12)
 
-	// Close the agent WebSocket so a new one can connect.
-	err = agent.close()
-	require.NoError(t, err)
-	_ = agent.recvErr(ctx, t)
-	agent.waitForClose(ctx, t)
+	// Close the agent channel so a new one can connect.
+	agent.Close(ctx)
 
 	// Create a new agent connection. This is to simulate a reconnect!
-	agent = newTestAgent(t, coordinator, "reconnection", agent.id)
-	// Ensure the existing listening connIO sends its node immediately!
-	clientNodes = agent.recvNodes(ctx, t)
-	require.Len(t, clientNodes, 1)
-	assert.Equal(t, 11, clientNodes[0].PreferredDERP)
+	agent = agpltest.NewPeer(ctx, t, coordinator, "reconnection", agpltest.WithID(agent.ID))
+	// Ensure the coordinator sends its client node immediately!
+	agent.AssertEventuallyHasDERP(client.ID, 11)
 
 	// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
 	// coordinator accidentally reordering things.
-	for d := 13; d < 36; d++ {
-		agent.sendNode(&agpl.Node{PreferredDERP: d})
-	}
-	for {
-		nodes := client.recvNodes(ctx, t)
-		if !assert.Len(t, nodes, 1) {
-			break
-		}
-		if nodes[0].PreferredDERP == 35 {
-			// got latest!
-			break
-		}
+	for d := int32(13); d < 36; d++ {
+		agent.UpdateDERP(d)
 	}
+	client.AssertEventuallyHasDERP(agent.ID, 35)
 
-	err = agent.close()
-	require.NoError(t, err)
-	_ = agent.recvErr(ctx, t)
-	agent.waitForClose(ctx, t)
-
-	err = client.close()
-	require.NoError(t, err)
-	_ = client.recvErr(ctx, t)
-	client.waitForClose(ctx, t)
-
-	assertEventuallyLost(ctx, t, store, agent.id)
-	assertEventuallyLost(ctx, t, store, client.id)
+	agent.UngracefulDisconnect(ctx)
+	client.UngracefulDisconnect(ctx)
+	assertEventuallyLost(ctx, t, store, agent.ID)
+	assertEventuallyLost(ctx, t, store, client.ID)
 }
 
 func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
@@ -305,16 +264,16 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "agent")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{PreferredDERP: 10})
+	agent := agpltest.NewAgent(ctx, t, coordinator, "agent")
+	defer agent.Close(ctx)
+	agent.UpdateDERP(10)
 
-	client := newTestClient(t, coordinator, agent.id)
-	defer client.close()
+	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
+	defer client.Close(ctx)
 
-	assertEventuallyHasDERPs(ctx, t, client, 10)
-	client.sendNode(&agpl.Node{PreferredDERP: 11})
-	assertEventuallyHasDERPs(ctx, t, agent, 11)
+	client.AssertEventuallyHasDERP(agent.ID, 10)
+	client.UpdateDERP(11)
+	agent.AssertEventuallyHasDERP(client.ID, 11)
 
 	// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
 	// real coordinator
@@ -328,8 +287,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	fCoord2.heartbeat()
 	afTrap.MustWait(ctx).Release() // heartbeat timeout started
 
-	fCoord2.agentNode(agent.id, &agpl.Node{PreferredDERP: 12})
-	assertEventuallyHasDERPs(ctx, t, client, 12)
+	fCoord2.agentNode(agent.ID, &agpl.Node{PreferredDERP: 12})
+	client.AssertEventuallyHasDERP(agent.ID, 12)
 
 	fCoord3 := &fakeCoordinator{
 		ctx: ctx,
@@ -339,8 +298,8 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	}
 	fCoord3.heartbeat()
 	rstTrap.MustWait(ctx).Release() // timeout gets reset
-	fCoord3.agentNode(agent.id, &agpl.Node{PreferredDERP: 13})
-	assertEventuallyHasDERPs(ctx, t, client, 13)
+	fCoord3.agentNode(agent.ID, &agpl.Node{PreferredDERP: 13})
+	client.AssertEventuallyHasDERP(agent.ID, 13)
 
 	// fCoord2 sends in a second heartbeat, one period later (on time)
 	mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
@@ -353,30 +312,22 @@ func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
 	w := mClock.Advance(tailnet.HeartbeatPeriod)
 	rstTrap.MustWait(ctx).Release()
 	w.MustWait(ctx)
-	assertEventuallyHasDERPs(ctx, t, client, 12)
+	client.AssertEventuallyHasDERP(agent.ID, 12)
 
 	// one more heartbeat period will result in fCoord2 being expired, which should cause us to
 	// revert to the original agent mapping
 	mClock.Advance(tailnet.HeartbeatPeriod).MustWait(ctx)
 	// note that the timeout doesn't get reset because both fCoord2 and fCoord3 are expired
-	assertEventuallyHasDERPs(ctx, t, client, 10)
+	client.AssertEventuallyHasDERP(agent.ID, 10)
 
 	// send fCoord3 heartbeat, which should trigger us to consider that mapping valid again.
 	fCoord3.heartbeat()
 	rstTrap.MustWait(ctx).Release() // timeout gets reset
-	assertEventuallyHasDERPs(ctx, t, client, 13)
+	client.AssertEventuallyHasDERP(agent.ID, 13)
 
-	err = agent.close()
-	require.NoError(t, err)
-	_ = agent.recvErr(ctx, t)
-	agent.waitForClose(ctx, t)
-
-	err = client.close()
-	require.NoError(t, err)
-	_ = client.recvErr(ctx, t)
-	client.waitForClose(ctx, t)
-
-	assertEventuallyLost(ctx, t, store, client.id)
+	agent.UngracefulDisconnect(ctx)
+	client.UngracefulDisconnect(ctx)
+	assertEventuallyLost(ctx, t, store, client.ID)
 }
 
 func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
@@ -420,7 +371,7 @@ func TestPGCoordinatorSingle_MissedHeartbeats_NoDrop(t *testing.T) {
 	// disconnect.
 	client.AssertEventuallyLost(agentID)
 
-	client.Close(ctx)
+	client.UngracefulDisconnect(ctx)
 
 	assertEventuallyLost(ctx, t, store, client.ID)
 }
@@ -491,104 +442,73 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
 	require.NoError(t, err)
 	defer coord2.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	t.Logf("agent1=%s", agent1.id)
-	agent2 := newTestAgent(t, coord2, "agent2")
-	defer agent2.close()
-	t.Logf("agent2=%s", agent2.id)
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	t.Logf("agent1=%s", agent1.ID)
+	agent2 := agpltest.NewAgent(ctx, t, coord2, "agent2")
+	defer agent2.Close(ctx)
+	t.Logf("agent2=%s", agent2.ID)
 
-	client11 := newTestClient(t, coord1, agent1.id)
-	defer client11.close()
-	t.Logf("client11=%s", client11.id)
-	client12 := newTestClient(t, coord1, agent2.id)
-	defer client12.close()
-	t.Logf("client12=%s", client12.id)
-	client21 := newTestClient(t, coord2, agent1.id)
-	defer client21.close()
-	t.Logf("client21=%s", client21.id)
-	client22 := newTestClient(t, coord2, agent2.id)
-	defer client22.close()
-	t.Logf("client22=%s", client22.id)
+	client11 := agpltest.NewClient(ctx, t, coord1, "client11", agent1.ID)
+	defer client11.Close(ctx)
+	t.Logf("client11=%s", client11.ID)
+	client12 := agpltest.NewClient(ctx, t, coord1, "client12", agent2.ID)
+	defer client12.Close(ctx)
+	t.Logf("client12=%s", client12.ID)
+	client21 := agpltest.NewClient(ctx, t, coord2, "client21", agent1.ID)
+	defer client21.Close(ctx)
+	t.Logf("client21=%s", client21.ID)
+	client22 := agpltest.NewClient(ctx, t, coord2, "client22", agent2.ID)
+	defer client22.Close(ctx)
+	t.Logf("client22=%s", client22.ID)
 
 	t.Logf("client11 -> Node 11")
-	client11.sendNode(&agpl.Node{PreferredDERP: 11})
-	assertEventuallyHasDERPs(ctx, t, agent1, 11)
+	client11.UpdateDERP(11)
+	agent1.AssertEventuallyHasDERP(client11.ID, 11)
 
 	t.Logf("client21 -> Node 21")
-	client21.sendNode(&agpl.Node{PreferredDERP: 21})
-	assertEventuallyHasDERPs(ctx, t, agent1, 21)
+	client21.UpdateDERP(21)
+	agent1.AssertEventuallyHasDERP(client21.ID, 21)
 
 	t.Logf("client22 -> Node 22")
-	client22.sendNode(&agpl.Node{PreferredDERP: 22})
-	assertEventuallyHasDERPs(ctx, t, agent2, 22)
+	client22.UpdateDERP(22)
+	agent2.AssertEventuallyHasDERP(client22.ID, 22)
 
 	t.Logf("agent2 -> Node 2")
-	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
-	assertEventuallyHasDERPs(ctx, t, client22, 2)
-	assertEventuallyHasDERPs(ctx, t, client12, 2)
+	agent2.UpdateDERP(2)
+	client22.AssertEventuallyHasDERP(agent2.ID, 2)
+	client12.AssertEventuallyHasDERP(agent2.ID, 2)
 
 	t.Logf("client12 -> Node 12")
-	client12.sendNode(&agpl.Node{PreferredDERP: 12})
-	assertEventuallyHasDERPs(ctx, t, agent2, 12)
+	client12.UpdateDERP(12)
+	agent2.AssertEventuallyHasDERP(client12.ID, 12)
 
 	t.Logf("agent1 -> Node 1")
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	assertEventuallyHasDERPs(ctx, t, client21, 1)
-	assertEventuallyHasDERPs(ctx, t, client11, 1)
+	agent1.UpdateDERP(1)
+	client21.AssertEventuallyHasDERP(agent1.ID, 1)
+	client11.AssertEventuallyHasDERP(agent1.ID, 1)
 
 	t.Logf("close coord2")
 	err = coord2.Close()
 	require.NoError(t, err)
 
 	// this closes agent2, client22, client21
-	err = agent2.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	err = client22.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	err = client21.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	assertEventuallyLost(ctx, t, store, agent2.id)
-	assertEventuallyLost(ctx, t, store, client21.id)
-	assertEventuallyLost(ctx, t, store, client22.id)
+	agent2.AssertEventuallyResponsesClosed()
+	client22.AssertEventuallyResponsesClosed()
+	client21.AssertEventuallyResponsesClosed()
+	assertEventuallyLost(ctx, t, store, agent2.ID)
+	assertEventuallyLost(ctx, t, store, client21.ID)
+	assertEventuallyLost(ctx, t, store, client22.ID)
 
 	err = coord1.Close()
 	require.NoError(t, err)
 	// this closes agent1, client12, client11
-	err = agent1.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	err = client12.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	err = client11.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.EOF)
-	assertEventuallyLost(ctx, t, store, agent1.id)
-	assertEventuallyLost(ctx, t, store, client11.id)
-	assertEventuallyLost(ctx, t, store, client12.id)
-
-	// wait for all connections to close
-	err = agent1.close()
-	require.NoError(t, err)
-	agent1.waitForClose(ctx, t)
-
-	err = agent2.close()
-	require.NoError(t, err)
-	agent2.waitForClose(ctx, t)
-
-	err = client11.close()
-	require.NoError(t, err)
-	client11.waitForClose(ctx, t)
-
-	err = client12.close()
-	require.NoError(t, err)
-	client12.waitForClose(ctx, t)
-
-	err = client21.close()
-	require.NoError(t, err)
-	client21.waitForClose(ctx, t)
-
-	err = client22.close()
-	require.NoError(t, err)
-	client22.waitForClose(ctx, t)
+	agent1.AssertEventuallyResponsesClosed()
+	client12.AssertEventuallyResponsesClosed()
+	client11.AssertEventuallyResponsesClosed()
+	assertEventuallyLost(ctx, t, store, agent1.ID)
+	assertEventuallyLost(ctx, t, store, client11.ID)
+	assertEventuallyLost(ctx, t, store, client12.ID)
 }
 
 // TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
@@ -623,53 +543,42 @@ func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
 	require.NoError(t, err)
 	defer coord3.Close()
 
-	agent1 := newTestAgent(t, coord1, "agent1")
-	defer agent1.close()
-	agent2 := newTestAgent(t, coord2, "agent2", agent1.id)
-	defer agent2.close()
+	agent1 := agpltest.NewAgent(ctx, t, coord1, "agent1")
+	defer agent1.Close(ctx)
+	agent2 := agpltest.NewPeer(ctx, t, coord2, "agent2",
+		agpltest.WithID(agent1.ID), agpltest.WithAuth(agpl.AgentCoordinateeAuth{ID: agent1.ID}),
+	)
+	defer agent2.Close(ctx)
 
-	client := newTestClient(t, coord3, agent1.id)
-	defer client.close()
+	client := agpltest.NewClient(ctx, t, coord3, "client", agent1.ID)
+	defer client.Close(ctx)
 
-	client.sendNode(&agpl.Node{PreferredDERP: 3})
-	assertEventuallyHasDERPs(ctx, t, agent1, 3)
-	assertEventuallyHasDERPs(ctx, t, agent2, 3)
+	client.UpdateDERP(3)
+	agent1.AssertEventuallyHasDERP(client.ID, 3)
+	agent2.AssertEventuallyHasDERP(client.ID, 3)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
-	assertEventuallyHasDERPs(ctx, t, client, 1)
+	agent1.UpdateDERP(1)
+	client.AssertEventuallyHasDERP(agent1.ID, 1)
 
 	// agent2's update overrides agent1 because it is newer
-	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
-	assertEventuallyHasDERPs(ctx, t, client, 2)
+	agent2.UpdateDERP(2)
+	client.AssertEventuallyHasDERP(agent1.ID, 2)
 
 	// agent2 disconnects, and we should revert back to agent1
-	err = agent2.close()
-	require.NoError(t, err)
-	err = agent2.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.ErrClosedPipe)
-	agent2.waitForClose(ctx, t)
-	assertEventuallyHasDERPs(ctx, t, client, 1)
+	agent2.Close(ctx)
+	client.AssertEventuallyHasDERP(agent1.ID, 1)
 
-	agent1.sendNode(&agpl.Node{PreferredDERP: 11})
-	assertEventuallyHasDERPs(ctx, t, client, 11)
+	agent1.UpdateDERP(11)
+	client.AssertEventuallyHasDERP(agent1.ID, 11)
 
-	client.sendNode(&agpl.Node{PreferredDERP: 31})
-	assertEventuallyHasDERPs(ctx, t, agent1, 31)
+	client.UpdateDERP(31)
+	agent1.AssertEventuallyHasDERP(client.ID, 31)
 
-	err = agent1.close()
-	require.NoError(t, err)
-	err = agent1.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.ErrClosedPipe)
-	agent1.waitForClose(ctx, t)
-
-	err = client.close()
-	require.NoError(t, err)
-	err = client.recvErr(ctx, t)
-	require.ErrorIs(t, err, io.ErrClosedPipe)
-	client.waitForClose(ctx, t)
+	agent1.UngracefulDisconnect(ctx)
+	client.UngracefulDisconnect(ctx)
 
-	assertEventuallyLost(ctx, t, store, client.id)
-	assertEventuallyLost(ctx, t, store, agent1.id)
+	assertEventuallyLost(ctx, t, store, client.ID)
+	assertEventuallyLost(ctx, t, store, agent1.ID)
 }
 
 func TestPGCoordinator_Unhealthy(t *testing.T) {
@@ -683,7 +592,13 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
 	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
 
 	calls := make(chan struct{})
+	// first call succeeds, so that our Agent will successfully connect.
+	firstSucceeds := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		Times(1).
+		Return(database.TailnetCoordinator{}, nil)
+	// next 3 fail, so the Coordinator becomes unhealthy, and we test that it disconnects the agent
 	threeMissed := mStore.EXPECT().UpsertTailnetCoordinator(gomock.Any(), gomock.Any()).
+		After(firstSucceeds).
 		Times(3).
 		Do(func(_ context.Context, _ uuid.UUID) { <-calls }).
 		Return(database.TailnetCoordinator{}, xerrors.New("test disconnect"))
@@ -710,23 +625,23 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
 		err := uut.Close()
 		require.NoError(t, err)
 	}()
-	agent1 := newTestAgent(t, uut, "agent1")
-	defer agent1.close()
+	agent1 := agpltest.NewAgent(ctx, t, uut, "agent1")
+	defer agent1.Close(ctx)
 	for i := 0; i < 3; i++ {
 		select {
 		case <-ctx.Done():
-			t.Fatal("timeout")
+			t.Fatalf("timeout waiting for call %d", i+1)
 		case calls <- struct{}{}:
 			// OK
 		}
 	}
 	// connected agent should be disconnected
-	agent1.waitForClose(ctx, t)
+	agent1.AssertEventuallyResponsesClosed()
 
 	// new agent should immediately disconnect
-	agent2 := newTestAgent(t, uut, "agent2")
-	defer agent2.close()
-	agent2.waitForClose(ctx, t)
+	agent2 := agpltest.NewAgent(ctx, t, uut, "agent2")
+	defer agent2.Close(ctx)
+	agent2.AssertEventuallyResponsesClosed()
 
 	// next heartbeats succeed, so we are healthy
 	for i := 0; i < 2; i++ {
@@ -737,14 +652,9 @@ func TestPGCoordinator_Unhealthy(t *testing.T) {
 			// OK
 		}
 	}
-	agent3 := newTestAgent(t, uut, "agent3")
-	defer agent3.close()
-	select {
-	case <-agent3.closeChan:
-		t.Fatal("agent conn closed after we are healthy")
-	case <-time.After(time.Second):
-		// OK
-	}
+	agent3 := agpltest.NewAgent(ctx, t, uut, "agent3")
+	defer agent3.Close(ctx)
+	agent3.AssertNotClosed(time.Second)
 }
 
 func TestPGCoordinator_Node_Empty(t *testing.T) {
@@ -840,43 +750,39 @@ func TestPGCoordinator_NoDeleteOnClose(t *testing.T) {
 	require.NoError(t, err)
 	defer coordinator.Close()
 
-	agent := newTestAgent(t, coordinator, "original")
-	defer agent.close()
-	agent.sendNode(&agpl.Node{PreferredDERP: 10})
+	agent := agpltest.NewAgent(ctx, t, coordinator, "original")
+	defer agent.Close(ctx)
+	agent.UpdateDERP(10)
 
-	client := newTestClient(t, coordinator, agent.id)
-	defer client.close()
+	client := agpltest.NewClient(ctx, t, coordinator, "client", agent.ID)
+	defer client.Close(ctx)
 
 	// Simulate some traffic to generate
 	// a peer.
-	agentNodes := client.recvNodes(ctx, t)
-	require.Len(t, agentNodes, 1)
-	assert.Equal(t, 10, agentNodes[0].PreferredDERP)
-	client.sendNode(&agpl.Node{PreferredDERP: 11})
+	client.AssertEventuallyHasDERP(agent.ID, 10)
+	client.UpdateDERP(11)
 
-	clientNodes := agent.recvNodes(ctx, t)
-	require.Len(t, clientNodes, 1)
-	assert.Equal(t, 11, clientNodes[0].PreferredDERP)
+	agent.AssertEventuallyHasDERP(client.ID, 11)
 
-	anode := coordinator.Node(agent.id)
+	anode := coordinator.Node(agent.ID)
 	require.NotNil(t, anode)
-	cnode := coordinator.Node(client.id)
+	cnode := coordinator.Node(client.ID)
 	require.NotNil(t, cnode)
 
 	err = coordinator.Close()
 	require.NoError(t, err)
-	assertEventuallyLost(ctx, t, store, agent.id)
-	assertEventuallyLost(ctx, t, store, client.id)
+	assertEventuallyLost(ctx, t, store, agent.ID)
+	assertEventuallyLost(ctx, t, store, client.ID)
 
 	coordinator2, err := tailnet.NewPGCoord(ctx, logger, ps, store)
 	require.NoError(t, err)
 	defer coordinator2.Close()
 
-	anode = coordinator2.Node(agent.id)
+	anode = coordinator2.Node(agent.ID)
 	require.NotNil(t, anode)
 	assert.Equal(t, 10, anode.PreferredDERP)
 
-	cnode = coordinator2.Node(client.id)
+	cnode = coordinator2.Node(client.ID)
 	require.NotNil(t, cnode)
 	assert.Equal(t, 11, cnode.PreferredDERP)
 }
@@ -1007,144 +913,6 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) {
 	p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED)
 }
 
-type testConn struct {
-	ws, serverWS net.Conn
-	nodeChan     chan []*agpl.Node
-	sendNode     func(node *agpl.Node)
-	errChan      <-chan error
-	id           uuid.UUID
-	closeChan    chan struct{}
-}
-
-func newTestConn(ids []uuid.UUID) *testConn {
-	a := &testConn{}
-	a.ws, a.serverWS = net.Pipe()
-	a.nodeChan = make(chan []*agpl.Node)
-	a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error {
-		a.nodeChan <- nodes
-		return nil
-	})
-	if len(ids) > 1 {
-		panic("too many")
-	}
-	if len(ids) == 1 {
-		a.id = ids[0]
-	} else {
-		a.id = uuid.New()
-	}
-	a.closeChan = make(chan struct{})
-	return a
-}
-
-func newTestAgent(t *testing.T, coord agpl.CoordinatorV1, name string, id ...uuid.UUID) *testConn {
-	a := newTestConn(id)
-	go func() {
-		err := coord.ServeAgent(a.serverWS, a.id, name)
-		assert.NoError(t, err)
-		close(a.closeChan)
-	}()
-	return a
-}
-
-func newTestClient(t *testing.T, coord agpl.CoordinatorV1, agentID uuid.UUID, id ...uuid.UUID) *testConn {
-	c := newTestConn(id)
-	go func() {
-		err := coord.ServeClient(c.serverWS, c.id, agentID)
-		assert.NoError(t, err)
-		close(c.closeChan)
-	}()
-	return c
-}
-
-func (c *testConn) close() error {
-	return c.ws.Close()
-}
-
-func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
-	t.Helper()
-	select {
-	case <-ctx.Done():
-		t.Fatalf("testConn id %s: timeout receiving nodes ", c.id)
-		return nil
-	case nodes := <-c.nodeChan:
-		return nodes
-	}
-}
-
-func (c *testConn) recvErr(ctx context.Context, t *testing.T) error {
-	t.Helper()
-	// pgCoord works on eventual consistency, so it sometimes sends extra node
-	// updates, and these block errors if not read from the nodes channel.
-	for {
-		select {
-		case nodes := <-c.nodeChan:
-			t.Logf("ignoring nodes update while waiting for error; id=%s, nodes=%+v",
-				c.id.String(), nodes)
-			continue
-		case <-ctx.Done():
-			t.Fatal("timeout receiving error")
-			return ctx.Err()
-		case err := <-c.errChan:
-			return err
-		}
-	}
-}
-
-func (c *testConn) waitForClose(ctx context.Context, t *testing.T) {
-	t.Helper()
-	select {
-	case <-ctx.Done():
-		t.Fatal("timeout waiting for connection to close")
-		return
-	case <-c.closeChan:
-		return
-	}
-}
-
-func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
-	t.Helper()
-	for {
-		nodes := c.recvNodes(ctx, t)
-		if len(nodes) != len(expected) {
-			t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
-			continue
-		}
-
-		derps := make([]int, 0, len(nodes))
-		for _, n := range nodes {
-			derps = append(derps, n.PreferredDERP)
-		}
-		for _, e := range expected {
-			if !slices.Contains(derps, e) {
-				t.Logf("expected DERP %d to be in %v", e, derps)
-				continue
-			}
-			return
-		}
-	}
-}
-
-func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
-	t.Helper()
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case nodes := <-c.nodeChan:
-			derps := make([]int, 0, len(nodes))
-			for _, n := range nodes {
-				derps = append(derps, n.PreferredDERP)
-			}
-			for _, e := range expected {
-				if slices.Contains(derps, e) {
-					t.Fatalf("expected not to get DERP %d, but received it", e)
-					return
-				}
-			}
-		}
-	}
-}
-
 func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
 	t.Helper()
 	assert.Eventually(t, func() bool {
@ -1,19 +1,13 @@
|
|||||||
package tailnet
|
package tailnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/xerrors"
|
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/apiversion"
|
"github.com/coder/coder/v2/apiversion"
|
||||||
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
|
|
||||||
agpl "github.com/coder/coder/v2/tailnet"
|
agpl "github.com/coder/coder/v2/tailnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,10 +32,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch major {
|
switch major {
|
||||||
case 1:
|
|
||||||
coord := *(s.CoordPtr.Load())
|
|
||||||
sub := coord.ServeMultiAgent(id)
|
|
||||||
return ServeWorkspaceProxy(ctx, conn, sub)
|
|
||||||
case 2:
|
case 2:
|
||||||
auth := agpl.SingleTailnetCoordinateeAuth{}
|
auth := agpl.SingleTailnetCoordinateeAuth{}
|
||||||
streamID := agpl.StreamID{
|
streamID := agpl.StreamID{
|
||||||
@ -52,103 +42,6 @@ func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version strin
|
|||||||
return s.ServeConnV2(ctx, conn, streamID)
|
return s.ServeConnV2(ctx, conn, streamID)
|
||||||
default:
|
default:
|
||||||
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
||||||
return xerrors.New("unsupported version")
|
return agpl.ErrUnsupportedVersion
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
|
|
||||||
go func() {
|
|
||||||
//nolint:staticcheck
|
|
||||||
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)
|
|
||||||
//nolint:staticcheck
|
|
||||||
if err != nil {
|
|
||||||
_ = conn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
decoder := json.NewDecoder(conn)
|
|
||||||
for {
|
|
||||||
var msg wsproxysdk.CoordinateMessage
|
|
||||||
err := decoder.Decode(&msg)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return xerrors.Errorf("read json: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg.Type {
|
|
||||||
case wsproxysdk.CoordinateMessageTypeSubscribe:
|
|
||||||
err := ma.SubscribeAgent(msg.AgentID)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("subscribe agent: %w", err)
|
|
||||||
}
|
|
||||||
case wsproxysdk.CoordinateMessageTypeUnsubscribe:
|
|
||||||
err := ma.UnsubscribeAgent(msg.AgentID)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("unsubscribe agent: %w", err)
|
|
||||||
}
|
|
||||||
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
|
|
||||||
pn, err := agpl.NodeToProto(msg.Node)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = ma.UpdateSelf(pn)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("update self: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return xerrors.Errorf("unknown message type %q", msg.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Linter fails because this function always returns an error. This function blocks
|
|
||||||
// until it errors, so this is ok.
|
|
||||||
//
|
|
||||||
//nolint:staticcheck
|
|
||||||
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
|
|
||||||
var lastData []byte
|
|
||||||
for {
|
|
||||||
resp, ok := ma.NextUpdate(ctx)
|
|
||||||
if !ok {
|
|
||||||
return xerrors.New("multiagent is closed")
|
|
||||||
}
|
|
||||||
nodes, err := agpl.OnlyNodeUpdates(resp)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("failed to convert response: %w", err)
|
|
||||||
}
|
|
||||||
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if bytes.Equal(lastData, data) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a deadline so that hung connections don't put back pressure on the system.
|
|
||||||
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
|
||||||
err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
|
|
||||||
if err != nil {
|
|
||||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
// often, this is just because the connection is closed/broken, so only log at debug.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
|
|
||||||
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
|
|
||||||
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
|
|
||||||
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
|
|
||||||
// our successful write, it is important that we reset the deadline before it fires.
|
|
||||||
err = conn.SetWriteDeadline(time.Time{})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
lastData = data
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
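
The removed forwardNodesToWorkspaceProxy above documents a write discipline worth keeping: skip payloads identical to the last one sent, set a short write deadline so a hung proxy connection cannot apply backpressure, and clear the deadline again after a successful write so a stale timer cannot fail the next one. Below is a stand-alone sketch of that pattern over a plain net.Conn; the writeTimeout constant and function name are illustrative, not the repository's.

package nodewrite

import (
	"bytes"
	"net"
	"time"
)

// writeTimeout is illustrative; the real constant (WriteTimeout) lives in the
// tailnet package.
const writeTimeout = 5 * time.Second

// writeIfChanged sends data only when it differs from the previous payload,
// guarding the write with a deadline and clearing the deadline once the write
// succeeds, so a stale timer cannot fail the next write.
func writeIfChanged(conn net.Conn, last, data []byte) ([]byte, error) {
	if bytes.Equal(last, data) {
		return last, nil // drop duplicate updates
	}
	// Node updates are tiny, so any live connection absorbs them well before
	// the deadline; the deadline only guards against a hung peer.
	if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		return last, err
	}
	if _, err := conn.Write(data); err != nil {
		return last, err
	}
	if err := conn.SetWriteDeadline(time.Time{}); err != nil {
		return last, err
	}
	return data, nil
}
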
@ -11,6 +11,7 @@ import {
|
|||||||
} from "../helpers";
|
} from "../helpers";
|
||||||
import { beforeCoderTest } from "../hooks";
|
import { beforeCoderTest } from "../hooks";
|
||||||
|
|
||||||
|
// we no longer support client versions that predate the Tailnet v2 API: https://github.com/coder/coder/commit/059e533544a0268acbc8831006b2858ead2f0d8e
|
||||||
const clientVersion = "v2.8.0";
|
const clientVersion = "v2.8.0";
|
||||||
|
|
||||||
test.beforeEach(({ page }) => beforeCoderTest(page));
|
test.beforeEach(({ page }) => beforeCoderTest(page));
|
||||||
|
@ -2,11 +2,9 @@ package tailnet
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@ -14,7 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
"nhooyr.io/websocket"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
|
||||||
@ -22,35 +19,23 @@ import (
|
|||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ResponseBufferSize is the max number of responses to buffer per connection before we start
|
||||||
|
// dropping updates
|
||||||
|
ResponseBufferSize = 512
|
||||||
|
// RequestBufferSize is the max number of requests to buffer per connection
|
||||||
|
RequestBufferSize = 32
|
||||||
|
)
|
||||||
|
|
||||||
// Coordinator exchanges nodes with agents to establish connections.
|
// Coordinator exchanges nodes with agents to establish connections.
|
||||||
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
|
// ┌──────────────────┐ ┌────────────────────┐ ┌───────────────────┐ ┌──────────────────┐
|
||||||
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
|
// │tailnet.Coordinate├──►│tailnet.AcceptClient│◄─►│tailnet.AcceptAgent│◄──┤tailnet.Coordinate│
|
||||||
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
|
// └──────────────────┘ └────────────────────┘ └───────────────────┘ └──────────────────┘
|
||||||
// Coordinators have different guarantees for HA support.
|
// Coordinators have different guarantees for HA support.
|
||||||
type Coordinator interface {
|
type Coordinator interface {
|
||||||
CoordinatorV1
|
|
||||||
CoordinatorV2
|
CoordinatorV2
|
||||||
}
|
}
|
||||||
|
|
||||||
type CoordinatorV1 interface {
|
|
||||||
// ServeHTTPDebug serves a debug webpage that shows the internal state of
|
|
||||||
// the coordinator.
|
|
||||||
ServeHTTPDebug(w http.ResponseWriter, r *http.Request)
|
|
||||||
// Node returns an in-memory node by ID.
|
|
||||||
Node(id uuid.UUID) *Node
|
|
||||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
|
||||||
// with the specified ID.
|
|
||||||
ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error
|
|
||||||
// ServeAgent accepts a WebSocket connection to an agent that listens to
|
|
||||||
// incoming connections and publishes node updates.
|
|
||||||
// Name is just used for debug information. It can be left blank.
|
|
||||||
ServeAgent(conn net.Conn, id uuid.UUID, name string) error
|
|
||||||
// Close closes the coordinator.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
ServeMultiAgent(id uuid.UUID) MultiAgentConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API.
|
// CoordinatorV2 is the interface for interacting with the coordinator via the 2.0 tailnet API.
|
||||||
type CoordinatorV2 interface {
|
type CoordinatorV2 interface {
|
||||||
// ServeHTTPDebug serves a debug webpage that shows the internal state of
|
// ServeHTTPDebug serves a debug webpage that shows the internal state of
|
||||||
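
The new ResponseBufferSize and RequestBufferSize constants above cap per-connection buffering, and the comment notes that responses are dropped once the buffer fills. A minimal stand-alone sketch of that non-blocking enqueue behaviour follows; the helper is illustrative and not the coordinator's actual queue.

package coordsketch

// enqueue sketches the "buffer up to N, then drop" behaviour the constants
// describe: a full channel never blocks the coordinator, the slow consumer
// simply misses that update.
func enqueue[T any](ch chan<- T, v T) (dropped bool) {
	select {
	case ch <- v:
		return false
	default:
		return true
	}
}
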
@ -60,6 +45,7 @@ type CoordinatorV2 interface {
|
|||||||
Node(id uuid.UUID) *Node
|
Node(id uuid.UUID) *Node
|
||||||
Close() error
|
Close() error
|
||||||
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
Coordinate(ctx context.Context, id uuid.UUID, name string, a CoordinateeAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse)
|
||||||
|
ServeMultiAgent(id uuid.UUID) MultiAgentConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Node represents a node in the network.
|
// Node represents a node in the network.
|
||||||
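
With ServeMultiAgent folded into CoordinatorV2, the Coordinate channel pair shown above is the single remaining programming surface. Here is a hedged sketch of how a caller typically drives it, built only from messages visible in this diff; GetPeerUpdates is the standard generated getter, and the flow is illustrative rather than the repository's client code.

package coordsketch

import (
	"context"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

// dialAgent opens a coordination stream, requests a tunnel to one agent, and
// drains responses until the coordinator closes the stream.
func dialAgent(ctx context.Context, coord tailnet.CoordinatorV2, self, agent uuid.UUID) {
	reqs, resps := coord.Coordinate(ctx, self, self.String(), tailnet.ClientCoordinateeAuth{AgentID: agent})
	defer close(reqs) // closing the request channel is the graceful disconnect
	select {
	case reqs <- &proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(agent)},
	}:
	case <-ctx.Done():
		return
	}
	for resp := range resps {
		_ = resp.GetPeerUpdates() // apply peer updates to the local netmap here
	}
}
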
@ -389,44 +375,6 @@ func (c *inMemoryCoordination) Close() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeCoordinator matches the RW structure of a coordinator to exchange node messages.
|
|
||||||
func ServeCoordinator(conn net.Conn, updateNodes func(node []*Node) error) (func(node *Node), <-chan error) {
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
sendErr := func(err error) {
|
|
||||||
select {
|
|
||||||
case errChan <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
decoder := json.NewDecoder(conn)
|
|
||||||
for {
|
|
||||||
var nodes []*Node
|
|
||||||
err := decoder.Decode(&nodes)
|
|
||||||
if err != nil {
|
|
||||||
sendErr(xerrors.Errorf("read: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = updateNodes(nodes)
|
|
||||||
if err != nil {
|
|
||||||
sendErr(xerrors.Errorf("update nodes: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return func(node *Node) {
|
|
||||||
data, err := json.Marshal(node)
|
|
||||||
if err != nil {
|
|
||||||
sendErr(xerrors.Errorf("marshal node: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
sendErr(xerrors.Errorf("write: %w", err))
|
|
||||||
}
|
|
||||||
}, errChan
|
|
||||||
}
|
|
||||||
|
|
||||||
const LoggerName = "coord"
|
const LoggerName = "coord"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -540,11 +488,11 @@ func ServeMultiAgent(c CoordinatorV2, logger slog.Logger, id uuid.UUID) MultiAge
|
|||||||
},
|
},
|
||||||
}).Init()
|
}).Init()
|
||||||
|
|
||||||
go v1RespLoop(ctx, cancel, logger, m, resps)
|
go qRespLoop(ctx, cancel, logger, m, resps)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// core is an in-memory structure of Node and TrackedConn mappings. Its methods may be called from multiple goroutines;
|
// core is an in-memory structure of peer mappings. Its methods may be called from multiple goroutines;
|
||||||
// it is protected by a mutex to ensure data stay consistent.
|
// it is protected by a mutex to ensure data stay consistent.
|
||||||
type core struct {
|
type core struct {
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
@ -607,42 +555,6 @@ func (c *core) node(id uuid.UUID) *Node {
|
|||||||
return v1Node
|
return v1Node
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeClient accepts a WebSocket connection that wants to connect to an agent
|
|
||||||
// with the specified ID.
|
|
||||||
func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
return ServeClientV1(ctx, c.core.logger, c, conn, id, agentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeClientV1 adapts a v1 Client to a v2 Coordinator
|
|
||||||
func ServeClientV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
|
||||||
logger = logger.With(slog.F("client_id", id), slog.F("agent_id", agent))
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug(ctx, "closing client connection", slog.Error(err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
reqs, resps := c.Coordinate(ctx, id, id.String(), ClientCoordinateeAuth{AgentID: agent})
|
|
||||||
err := SendCtx(ctx, reqs, &proto.CoordinateRequest{
|
|
||||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: UUIDToByteSlice(agent)},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// can only be a context error, no need to log here.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient)
|
|
||||||
go tc.SendUpdates()
|
|
||||||
go v1RespLoop(ctx, cancel, logger, tc, resps)
|
|
||||||
go v1ReqLoop(ctx, cancel, logger, conn, reqs)
|
|
||||||
<-ctx.Done()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
|
func (c *core) handleRequest(p *peer, req *proto.CoordinateRequest) error {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
@ -887,34 +799,6 @@ func (c *core) removePeerLocked(id uuid.UUID, kind proto.CoordinateResponse_Peer
|
|||||||
delete(c.peers, id)
|
delete(c.peers, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeAgent accepts a WebSocket connection to an agent that
|
|
||||||
// listens to incoming connections and publishes node updates.
|
|
||||||
func (c *coordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) error {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
return ServeAgentV1(ctx, c.core.logger, c, conn, id, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ServeAgentV1(ctx context.Context, logger slog.Logger, c CoordinatorV2, conn net.Conn, id uuid.UUID, name string) error {
|
|
||||||
logger = logger.With(slog.F("agent_id", id), slog.F("name", name))
|
|
||||||
defer func() {
|
|
||||||
logger.Debug(ctx, "closing agent connection")
|
|
||||||
err := conn.Close()
|
|
||||||
logger.Debug(ctx, "closed agent connection", slog.Error(err))
|
|
||||||
}()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
logger.Debug(ctx, "starting new agent connection")
|
|
||||||
reqs, resps := c.Coordinate(ctx, id, name, AgentCoordinateeAuth{ID: id})
|
|
||||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, QueueKindAgent)
|
|
||||||
go tc.SendUpdates()
|
|
||||||
go v1RespLoop(ctx, cancel, logger, tc, resps)
|
|
||||||
go v1ReqLoop(ctx, cancel, logger, conn, reqs)
|
|
||||||
<-ctx.Done()
|
|
||||||
logger.Debug(ctx, "ending agent connection")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes all of the open connections in the coordinator and stops the
|
// Close closes all of the open connections in the coordinator and stops the
|
||||||
// coordinator from accepting new connections.
|
// coordinator from accepting new connections.
|
||||||
func (c *coordinator) Close() error {
|
func (c *coordinator) Close() error {
|
||||||
@ -1073,44 +957,7 @@ func RecvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func v1ReqLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger,
|
func qRespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
|
||||||
conn net.Conn, reqs chan<- *proto.CoordinateRequest,
|
|
||||||
) {
|
|
||||||
defer close(reqs)
|
|
||||||
defer cancel()
|
|
||||||
decoder := json.NewDecoder(conn)
|
|
||||||
for {
|
|
||||||
var node Node
|
|
||||||
err := decoder.Decode(&node)
|
|
||||||
if err != nil {
|
|
||||||
if xerrors.Is(err, io.EOF) ||
|
|
||||||
xerrors.Is(err, io.ErrClosedPipe) ||
|
|
||||||
xerrors.Is(err, context.Canceled) ||
|
|
||||||
xerrors.Is(err, context.DeadlineExceeded) ||
|
|
||||||
websocket.CloseStatus(err) > 0 {
|
|
||||||
logger.Debug(ctx, "v1ReqLoop exiting", slog.Error(err))
|
|
||||||
} else {
|
|
||||||
logger.Info(ctx, "v1ReqLoop failed to decode Node update", slog.Error(err))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
logger.Debug(ctx, "v1ReqLoop got node update", slog.F("node", node))
|
|
||||||
pn, err := NodeToProto(&node)
|
|
||||||
if err != nil {
|
|
||||||
logger.Critical(ctx, "v1ReqLoop failed to convert v1 node", slog.F("node", node), slog.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{
|
|
||||||
Node: pn,
|
|
||||||
}}
|
|
||||||
if err := SendCtx(ctx, reqs, req); err != nil {
|
|
||||||
logger.Debug(ctx, "v1ReqLoop ctx expired", slog.Error(err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logger, q Queue, resps <-chan *proto.CoordinateResponse) {
|
|
||||||
defer func() {
|
defer func() {
|
||||||
cErr := q.Close()
|
cErr := q.Close()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
@ -1121,13 +968,13 @@ func v1RespLoop(ctx context.Context, cancel context.CancelFunc, logger slog.Logg
|
|||||||
for {
|
for {
|
||||||
resp, err := RecvCtx(ctx, resps)
|
resp, err := RecvCtx(ctx, resps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug(ctx, "v1RespLoop done reading responses", slog.Error(err))
|
logger.Debug(ctx, "qRespLoop done reading responses", slog.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Debug(ctx, "v1RespLoop got response", slog.F("resp", resp))
|
logger.Debug(ctx, "qRespLoop got response", slog.F("resp", resp))
|
||||||
err = q.Enqueue(resp)
|
err = q.Enqueue(resp)
|
||||||
if err != nil && !xerrors.Is(err, context.Canceled) {
|
if err != nil && !xerrors.Is(err, context.Canceled) {
|
||||||
logger.Error(ctx, "v1RespLoop failed to enqueue v1 update", slog.Error(err))
|
logger.Error(ctx, "qRespLoop failed to enqueue v1 update", slog.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
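
The surviving loops above pair every channel operation with the stream context via SendCtx and RecvCtx. A generic stand-alone sketch of that helper shape follows; the real functions live in the tailnet package and differ in detail.

package coordsketch

import (
	"context"
	"errors"
)

// sendCtx mirrors the shape of SendCtx: a send that gives up when the context
// is canceled instead of blocking forever.
func sendCtx[A any](ctx context.Context, c chan<- A, a A) error {
	select {
	case c <- a:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// recvCtx is the receiving counterpart, also reporting a closed channel.
func recvCtx[A any](ctx context.Context, c <-chan A) (a A, err error) {
	select {
	case <-ctx.Done():
		return a, ctx.Err()
	case a, ok := <-c:
		if !ok {
			return a, errors.New("channel closed")
		}
		return a, nil
	}
}
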
@ -2,10 +2,7 @@ package tailnet_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -13,10 +10,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
"nhooyr.io/websocket"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
|
||||||
@ -34,162 +29,107 @@ func TestCoordinator(t *testing.T) {
|
|||||||
t.Run("ClientWithoutAgent", func(t *testing.T) {
|
t.Run("ClientWithoutAgent", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := coordinator.Close()
|
err := coordinator.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
client, server := net.Pipe()
|
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
|
||||||
return nil
|
defer client.Close(ctx)
|
||||||
})
|
client.UpdateNode(&proto.Node{
|
||||||
id := uuid.New()
|
Addresses: []string{netip.PrefixFrom(tailnet.IP(), 128).String()},
|
||||||
closeChan := make(chan struct{})
|
PreferredDerp: 10,
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeClient(server, id, uuid.New())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeChan)
|
|
||||||
}()
|
|
||||||
sendNode(&tailnet.Node{
|
|
||||||
Addresses: []netip.Prefix{
|
|
||||||
netip.PrefixFrom(tailnet.IP(), 128),
|
|
||||||
},
|
|
||||||
PreferredDERP: 10,
|
|
||||||
})
|
})
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
return coordinator.Node(id) != nil
|
return coordinator.Node(client.ID) != nil
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
require.NoError(t, client.Close())
|
|
||||||
require.NoError(t, server.Close())
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
|
t.Run("ClientWithoutAgent_InvalidIPBits", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := coordinator.Close()
|
err := coordinator.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
client, server := net.Pipe()
|
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
id := uuid.New()
|
|
||||||
closeChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeClient(server, id, uuid.New())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeChan)
|
|
||||||
}()
|
|
||||||
sendNode(&tailnet.Node{
|
|
||||||
Addresses: []netip.Prefix{
|
|
||||||
netip.PrefixFrom(tailnet.IP(), 64),
|
|
||||||
},
|
|
||||||
PreferredDERP: 10,
|
|
||||||
})
|
|
||||||
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
client := test.NewClient(ctx, t, coordinator, "client", uuid.New())
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
defer client.Close(ctx)
|
||||||
|
|
||||||
|
client.UpdateNode(&proto.Node{
|
||||||
|
Addresses: []string{
|
||||||
|
netip.PrefixFrom(tailnet.IP(), 64).String(),
|
||||||
|
},
|
||||||
|
PreferredDerp: 10,
|
||||||
|
})
|
||||||
|
client.AssertEventuallyResponsesClosed()
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithoutClients", func(t *testing.T) {
|
t.Run("AgentWithoutClients", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := coordinator.Close()
|
err := coordinator.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
client, server := net.Pipe()
|
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
agent := test.NewAgent(ctx, t, coordinator, "agent")
|
||||||
return nil
|
defer agent.Close(ctx)
|
||||||
})
|
agent.UpdateNode(&proto.Node{
|
||||||
id := uuid.New()
|
Addresses: []string{
|
||||||
closeChan := make(chan struct{})
|
netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 128).String(),
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(server, id, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeChan)
|
|
||||||
}()
|
|
||||||
sendNode(&tailnet.Node{
|
|
||||||
Addresses: []netip.Prefix{
|
|
||||||
netip.PrefixFrom(tailnet.IPFromUUID(id), 128),
|
|
||||||
},
|
},
|
||||||
PreferredDERP: 10,
|
PreferredDerp: 10,
|
||||||
})
|
})
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
return coordinator.Node(id) != nil
|
return coordinator.Node(agent.ID) != nil
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
err := client.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
|
t.Run("AgentWithoutClients_InvalidIP", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := coordinator.Close()
|
err := coordinator.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
client, server := net.Pipe()
|
agent := test.NewAgent(ctx, t, coordinator, "agent")
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
defer agent.Close(ctx)
|
||||||
return nil
|
agent.UpdateNode(&proto.Node{
|
||||||
})
|
Addresses: []string{
|
||||||
id := uuid.New()
|
netip.PrefixFrom(tailnet.IP(), 128).String(),
|
||||||
closeChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(server, id, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeChan)
|
|
||||||
}()
|
|
||||||
sendNode(&tailnet.Node{
|
|
||||||
Addresses: []netip.Prefix{
|
|
||||||
netip.PrefixFrom(tailnet.IP(), 128),
|
|
||||||
},
|
},
|
||||||
PreferredDERP: 10,
|
PreferredDerp: 10,
|
||||||
})
|
})
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
agent.AssertEventuallyResponsesClosed()
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
|
t.Run("AgentWithoutClients_InvalidBits", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := coordinator.Close()
|
err := coordinator.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
client, server := net.Pipe()
|
agent := test.NewAgent(ctx, t, coordinator, "agent")
|
||||||
sendNode, errChan := tailnet.ServeCoordinator(client, func(node []*tailnet.Node) error {
|
defer agent.Close(ctx)
|
||||||
return nil
|
agent.UpdateNode(&proto.Node{
|
||||||
})
|
Addresses: []string{
|
||||||
id := uuid.New()
|
netip.PrefixFrom(tailnet.IPFromUUID(agent.ID), 64).String(),
|
||||||
closeChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(server, id, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeChan)
|
|
||||||
}()
|
|
||||||
sendNode(&tailnet.Node{
|
|
||||||
Addresses: []netip.Prefix{
|
|
||||||
netip.PrefixFrom(tailnet.IPFromUUID(id), 64),
|
|
||||||
},
|
},
|
||||||
PreferredDERP: 10,
|
PreferredDerp: 10,
|
||||||
})
|
})
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, errChan)
|
agent.AssertEventuallyResponsesClosed()
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentWithClient", func(t *testing.T) {
|
t.Run("AgentWithClient", func(t *testing.T) {
|
||||||
@ -201,180 +141,71 @@ func TestCoordinator(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// in this test we use real websockets to test use of deadlines
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
agentWS, agentServerWS := websocketConn(ctx, t)
|
agent := test.NewAgent(ctx, t, coordinator, "agent")
|
||||||
defer agentWS.Close()
|
defer agent.Close(ctx)
|
||||||
agentNodeChan := make(chan []*tailnet.Node)
|
agent.UpdateDERP(1)
|
||||||
sendAgentNode, agentErrChan := tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
|
|
||||||
agentNodeChan <- nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
agentID := uuid.New()
|
|
||||||
closeAgentChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(agentServerWS, agentID, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeAgentChan)
|
|
||||||
}()
|
|
||||||
sendAgentNode(&tailnet.Node{PreferredDERP: 1})
|
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
return coordinator.Node(agentID) != nil
|
return coordinator.Node(agent.ID) != nil
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
|
||||||
clientWS, clientServerWS := websocketConn(ctx, t)
|
client := test.NewClient(ctx, t, coordinator, "client", agent.ID)
|
||||||
defer clientWS.Close()
|
defer client.Close(ctx)
|
||||||
defer clientServerWS.Close()
|
client.AssertEventuallyHasDERP(agent.ID, 1)
|
||||||
clientNodeChan := make(chan []*tailnet.Node)
|
|
||||||
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
|
||||||
clientNodeChan <- nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
clientID := uuid.New()
|
|
||||||
closeClientChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeClientChan)
|
|
||||||
}()
|
|
||||||
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
|
|
||||||
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
client.UpdateDERP(2)
|
||||||
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan)
|
agent.AssertEventuallyHasDERP(client.ID, 2)
|
||||||
require.Len(t, clientNodes, 1)
|
|
||||||
|
|
||||||
// wait longer than the internal wait timeout.
|
|
||||||
// this tests for regression of https://github.com/coder/coder/issues/7428
|
|
||||||
time.Sleep(tailnet.WriteTimeout * 3 / 2)
|
|
||||||
|
|
||||||
// Ensure an update to the agent node reaches the client!
|
// Ensure an update to the agent node reaches the client!
|
||||||
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
|
agent.UpdateDERP(3)
|
||||||
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
client.AssertEventuallyHasDERP(agent.ID, 3)
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
|
|
||||||
// Close the agent WebSocket so a new one can connect.
|
// Close the agent so a new one can connect.
|
||||||
err := agentWS.Close()
|
agent.Close(ctx)
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
|
|
||||||
|
|
||||||
// Create a new agent connection. This is to simulate a reconnect!
|
// Create a new agent connection. This is to simulate a reconnect!
|
||||||
agentWS, agentServerWS = net.Pipe()
|
agent = test.NewPeer(ctx, t, coordinator, "agent", test.WithID(agent.ID))
|
||||||
defer agentWS.Close()
|
defer agent.Close(ctx)
|
||||||
agentNodeChan = make(chan []*tailnet.Node)
|
// Ensure the agent gets the existing client node immediately!
|
||||||
_, agentErrChan = tailnet.ServeCoordinator(agentWS, func(nodes []*tailnet.Node) error {
|
agent.AssertEventuallyHasDERP(client.ID, 2)
|
||||||
agentNodeChan <- nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
closeAgentChan = make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(agentServerWS, agentID, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeAgentChan)
|
|
||||||
}()
|
|
||||||
// Ensure the existing listening client sends its node immediately!
|
|
||||||
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan)
|
|
||||||
require.Len(t, clientNodes, 1)
|
|
||||||
|
|
||||||
err = agentWS.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan)
|
|
||||||
|
|
||||||
err = clientWS.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
t.Run("AgentDoubleConnect", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
coordinator := tailnet.NewCoordinator(logger)
|
||||||
ctx := testutil.Context(t, testutil.WaitLong)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
|
||||||
agentWS1, agentServerWS1 := net.Pipe()
|
|
||||||
defer agentWS1.Close()
|
|
||||||
agentNodeChan1 := make(chan []*tailnet.Node)
|
|
||||||
sendAgentNode1, agentErrChan1 := tailnet.ServeCoordinator(agentWS1, func(nodes []*tailnet.Node) error {
|
|
||||||
t.Logf("agent1 got node update: %v", nodes)
|
|
||||||
agentNodeChan1 <- nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
agentID := uuid.New()
|
agentID := uuid.New()
|
||||||
closeAgentChan1 := make(chan struct{})
|
agent1 := test.NewPeer(ctx, t, coordinator, "agent1", test.WithID(agentID))
|
||||||
go func() {
|
defer agent1.Close(ctx)
|
||||||
err := coordinator.ServeAgent(agentServerWS1, agentID, "")
|
agent1.UpdateDERP(1)
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeAgentChan1)
|
|
||||||
}()
|
|
||||||
sendAgentNode1(&tailnet.Node{PreferredDERP: 1})
|
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
return coordinator.Node(agentID) != nil
|
return coordinator.Node(agentID) != nil
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
|
||||||
clientWS, clientServerWS := net.Pipe()
|
client := test.NewPeer(ctx, t, coordinator, "client")
|
||||||
defer clientWS.Close()
|
defer client.Close(ctx)
|
||||||
defer clientServerWS.Close()
|
client.AddTunnel(agentID)
|
||||||
clientNodeChan := make(chan []*tailnet.Node)
|
client.AssertEventuallyHasDERP(agent1.ID, 1)
|
||||||
sendClientNode, clientErrChan := tailnet.ServeCoordinator(clientWS, func(nodes []*tailnet.Node) error {
|
|
||||||
t.Logf("client got node update: %v", nodes)
|
client.UpdateDERP(2)
|
||||||
clientNodeChan <- nodes
|
agent1.AssertEventuallyHasDERP(client.ID, 2)
|
||||||
return nil
|
|
||||||
})
|
|
||||||
clientID := uuid.New()
|
|
||||||
closeClientChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeClientChan)
|
|
||||||
}()
|
|
||||||
agentNodes := testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
sendClientNode(&tailnet.Node{PreferredDERP: 2})
|
|
||||||
clientNodes := testutil.RequireRecvCtx(ctx, t, agentNodeChan1)
|
|
||||||
require.Len(t, clientNodes, 1)
|
|
||||||
|
|
||||||
// Ensure an update to the agent node reaches the client!
|
// Ensure an update to the agent node reaches the client!
|
||||||
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
|
agent1.UpdateDERP(3)
|
||||||
agentNodes = testutil.RequireRecvCtx(ctx, t, clientNodeChan)
|
client.AssertEventuallyHasDERP(agent1.ID, 3)
|
||||||
require.Len(t, agentNodes, 1)
|
|
||||||
|
|
||||||
// Create a new agent connection without disconnecting the old one.
|
// Create a new agent connection without disconnecting the old one.
|
||||||
agentWS2, agentServerWS2 := net.Pipe()
|
agent2 := test.NewPeer(ctx, t, coordinator, "agent2", test.WithID(agentID))
|
||||||
defer agentWS2.Close()
|
defer agent2.Close(ctx)
|
||||||
agentNodeChan2 := make(chan []*tailnet.Node)
|
|
||||||
_, agentErrChan2 := tailnet.ServeCoordinator(agentWS2, func(nodes []*tailnet.Node) error {
|
|
||||||
t.Logf("agent2 got node update: %v", nodes)
|
|
||||||
agentNodeChan2 <- nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
closeAgentChan2 := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(agentServerWS2, agentID, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
close(closeAgentChan2)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Ensure the existing listening client sends it's node immediately!
|
// Ensure the existing client node gets sent immediately!
|
||||||
clientNodes = testutil.RequireRecvCtx(ctx, t, agentNodeChan2)
|
agent2.AssertEventuallyHasDERP(client.ID, 2)
|
||||||
require.Len(t, clientNodes, 1)
|
|
||||||
|
|
||||||
// This original agent websocket should've been closed forcefully.
|
// The original agent's channels should've been closed forcefully.
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan1)
|
agent1.AssertEventuallyResponsesClosed()
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan1)
|
|
||||||
|
|
||||||
err := agentWS2.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, agentErrChan2)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeAgentChan2)
|
|
||||||
|
|
||||||
err = clientWS.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, clientErrChan)
|
|
||||||
_ = testutil.RequireRecvCtx(ctx, t, closeClientChan)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AgentAck", func(t *testing.T) {
|
t.Run("AgentAck", func(t *testing.T) {
|
||||||
@ -396,89 +227,6 @@ func TestCoordinator(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestCoordinator_AgentUpdateWhileClientConnects tests for regression on
|
|
||||||
// https://github.com/coder/coder/issues/7295
|
|
||||||
func TestCoordinator_AgentUpdateWhileClientConnects(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
||||||
coordinator := tailnet.NewCoordinator(logger)
|
|
||||||
agentWS, agentServerWS := net.Pipe()
|
|
||||||
defer agentWS.Close()
|
|
||||||
|
|
||||||
agentID := uuid.New()
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeAgent(agentServerWS, agentID, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// send an agent update before the client connects so that there is
|
|
||||||
// node data available to send right away.
|
|
||||||
aNode := tailnet.Node{PreferredDERP: 0}
|
|
||||||
aData, err := json.Marshal(&aNode)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = agentWS.Write(aData)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Eventually(t, func() bool {
|
|
||||||
return coordinator.Node(agentID) != nil
|
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
|
||||||
|
|
||||||
// Connect from the client
|
|
||||||
clientWS, clientServerWS := net.Pipe()
|
|
||||||
defer clientWS.Close()
|
|
||||||
clientID := uuid.New()
|
|
||||||
go func() {
|
|
||||||
err := coordinator.ServeClient(clientServerWS, clientID, agentID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// peek one byte from the node update, so we know the coordinator is
|
|
||||||
// trying to write to the client.
|
|
||||||
// buffer needs to be 2 characters longer because return value is a list
|
|
||||||
// so, it needs [ and ]
|
|
||||||
buf := make([]byte, len(aData)+2)
|
|
||||||
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
|
|
||||||
require.NoError(t, err)
|
|
||||||
n, err := clientWS.Read(buf[:1])
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, n)
|
|
||||||
|
|
||||||
// send a second update
|
|
||||||
aNode.PreferredDERP = 1
|
|
||||||
require.NoError(t, err)
|
|
||||||
aData, err = json.Marshal(&aNode)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = agentWS.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = agentWS.Write(aData)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// read the rest of the update from the client, should be initial node.
|
|
||||||
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
|
|
||||||
require.NoError(t, err)
|
|
||||||
n, err = clientWS.Read(buf[1:])
|
|
||||||
require.NoError(t, err)
|
|
||||||
var cNodes []*tailnet.Node
|
|
||||||
err = json.Unmarshal(buf[:n+1], &cNodes)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, cNodes, 1)
|
|
||||||
require.Equal(t, 0, cNodes[0].PreferredDERP)
|
|
||||||
|
|
||||||
// read second update
|
|
||||||
// without a fix for https://github.com/coder/coder/issues/7295 our
|
|
||||||
// read would time out here.
|
|
||||||
err = clientWS.SetReadDeadline(time.Now().Add(testutil.WaitShort))
|
|
||||||
require.NoError(t, err)
|
|
||||||
n, err = clientWS.Read(buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = json.Unmarshal(buf[:n], &cNodes)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, cNodes, 1)
|
|
||||||
require.Equal(t, 1, cNodes[0].PreferredDERP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
|
func TestCoordinator_BidirectionalTunnels(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
@ -521,29 +269,6 @@ func TestCoordinator_MultiAgent_CoordClose(t *testing.T) {
|
|||||||
ma1.RequireEventuallyClosed(ctx)
|
ma1.RequireEventuallyClosed(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func websocketConn(ctx context.Context, t *testing.T) (client net.Conn, server net.Conn) {
|
|
||||||
t.Helper()
|
|
||||||
sc := make(chan net.Conn, 1)
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
wss, err := websocket.Accept(rw, r, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
conn := websocket.NetConn(r.Context(), wss, websocket.MessageBinary)
|
|
||||||
sc <- conn
|
|
||||||
close(sc) // there can be only one
|
|
||||||
|
|
||||||
// hold open until context canceled
|
|
||||||
<-ctx.Done()
|
|
||||||
}))
|
|
||||||
t.Cleanup(s.Close)
|
|
||||||
// nolint: bodyclose
|
|
||||||
wsc, _, err := websocket.Dial(ctx, s.URL, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
client = websocket.NetConn(ctx, wsc, websocket.MessageBinary)
|
|
||||||
server, ok := <-sc
|
|
||||||
require.True(t, ok)
|
|
||||||
return client, server
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInMemoryCoordination(t *testing.T) {
|
func TestInMemoryCoordination(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := testutil.Context(t, testutil.WaitShort)
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
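
The rewritten subtests above all follow the same shape with the v2 test helpers: build a peer, publish a node, and assert on what the other side eventually sees. Below is a condensed sketch of that flow, using only calls that appear in the hunks above; it is illustrative rather than a copy of any one subtest.

package tailnet_test

import (
	"testing"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/test"
	"github.com/coder/coder/v2/testutil"
)

func TestCoordinator_Sketch(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	ctx := testutil.Context(t, testutil.WaitShort)
	coordinator := tailnet.NewCoordinator(logger)
	defer coordinator.Close()

	// An agent and a client tunneled to it, driven entirely through the v2
	// helpers instead of raw websockets and ServeAgent/ServeClient.
	agent := test.NewAgent(ctx, t, coordinator, "agent")
	defer agent.Close(ctx)
	agent.UpdateDERP(1)

	client := test.NewClient(ctx, t, coordinator, "client", agent.ID)
	defer client.Close(ctx)
	client.AssertEventuallyHasDERP(agent.ID, 1)

	client.UpdateDERP(2)
	agent.AssertEventuallyHasDERP(client.ID, 2)
}
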
@ -9,4 +9,4 @@ const (
|
|||||||
CurrentMinor = 2
|
CurrentMinor = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
|
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
|
||||||
|
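
Dropping WithBackwardCompat(1) above is what actually turns away v1 clients: a requested major version now has to match the current major exactly. A stand-alone illustration of that rule follows; this is not the apiversion package's code, only the semantics the change implies.

package versionsketch

import "fmt"

// validateMajor accepts a requested major version only if it equals the
// current major or is explicitly listed as backward compatible. With an empty
// list, a 1.x request falls through to the error.
func validateMajor(current int, backwardCompat []int, requested int) error {
	if requested == current {
		return nil
	}
	for _, m := range backwardCompat {
		if requested == m {
			return nil
		}
	}
	return fmt.Errorf("unsupported major version %d (current %d)", requested, current)
}

That rejected path is what surfaces later in this diff as ErrUnsupportedVersion.
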
@ -21,6 +21,8 @@ import (
|
|||||||
"github.com/coder/quartz"
|
"github.com/coder/quartz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrUnsupportedVersion = xerrors.New("unsupported version")
|
||||||
|
|
||||||
type streamIDContextKey struct{}
|
type streamIDContextKey struct{}
|
||||||
|
|
||||||
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
||||||
@ -47,7 +49,7 @@ type ClientServiceOptions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ClientService is a tailnet coordination service that accepts a connection and version from a
|
// ClientService is a tailnet coordination service that accepts a connection and version from a
|
||||||
// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
|
// tailnet client, and supports versions 2.x of the Tailnet API protocol.
|
||||||
type ClientService struct {
|
type ClientService struct {
|
||||||
Logger slog.Logger
|
Logger slog.Logger
|
||||||
CoordPtr *atomic.Pointer[Coordinator]
|
CoordPtr *atomic.Pointer[Coordinator]
|
||||||
@ -94,9 +96,6 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
switch major {
|
switch major {
|
||||||
case 1:
|
|
||||||
coord := *(s.CoordPtr.Load())
|
|
||||||
return coord.ServeClient(conn, id, agent)
|
|
||||||
case 2:
|
case 2:
|
||||||
auth := ClientCoordinateeAuth{AgentID: agent}
|
auth := ClientCoordinateeAuth{AgentID: agent}
|
||||||
streamID := StreamID{
|
streamID := StreamID{
|
||||||
@ -107,7 +106,7 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne
|
|||||||
return s.ServeConnV2(ctx, conn, streamID)
|
return s.ServeConnV2(ctx, conn, streamID)
|
||||||
default:
|
default:
|
||||||
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
||||||
return xerrors.New("unsupported version")
|
return ErrUnsupportedVersion
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,21 +162,8 @@ func TestClientService_ServeClient_V1(t *testing.T) {
|
|||||||
errCh <- err
|
errCh <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
call := testutil.RequireRecvCtx(ctx, t, fCoord.ServeClientCalls)
|
|
||||||
require.NotNil(t, call)
|
|
||||||
require.Equal(t, call.ID, clientID)
|
|
||||||
require.Equal(t, call.Agent, agentID)
|
|
||||||
require.Equal(t, s, call.Conn)
|
|
||||||
expectedError := xerrors.New("test error")
|
|
||||||
select {
|
|
||||||
case call.ErrCh <- expectedError:
|
|
||||||
// ok!
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.Fatalf("timeout sending error")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||||
require.ErrorIs(t, err, expectedError)
|
require.ErrorIs(t, err, tailnet.ErrUnsupportedVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNetworkTelemetryBatcher(t *testing.T) {
|
func TestNetworkTelemetryBatcher(t *testing.T) {
|
||||||
|
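
Promoting the error to a package-level ErrUnsupportedVersion above is what lets the test assert with require.ErrorIs instead of routing a fabricated error through the fake coordinator. A minimal stand-alone illustration of the sentinel-error pattern follows; all names below are hypothetical.

package sentinel

import (
	"errors"
	"fmt"
)

// errUnsupported is a sentinel: callers compare with errors.Is rather than
// string matching, and wrapping with %w preserves its identity.
var errUnsupported = errors.New("unsupported version")

func serve(version string) error {
	if version != "2.0" {
		return fmt.Errorf("serve version %q: %w", version, errUnsupported)
	}
	return nil
}

func isUnsupported(err error) bool {
	return errors.Is(err, errUnsupported)
}
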
@ -11,7 +11,6 @@ package tailnettest
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
net "net"
|
|
||||||
http "net/http"
|
http "net/http"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@ -87,34 +86,6 @@ func (mr *MockCoordinatorMockRecorder) Node(arg0 any) *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Node", reflect.TypeOf((*MockCoordinator)(nil).Node), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeAgent mocks base method.
|
|
||||||
func (m *MockCoordinator) ServeAgent(arg0 net.Conn, arg1 uuid.UUID, arg2 string) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "ServeAgent", arg0, arg1, arg2)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeAgent indicates an expected call of ServeAgent.
|
|
||||||
func (mr *MockCoordinatorMockRecorder) ServeAgent(arg0, arg1, arg2 any) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeAgent", reflect.TypeOf((*MockCoordinator)(nil).ServeAgent), arg0, arg1, arg2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeClient mocks base method.
|
|
||||||
func (m *MockCoordinator) ServeClient(arg0 net.Conn, arg1, arg2 uuid.UUID) error {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "ServeClient", arg0, arg1, arg2)
|
|
||||||
ret0, _ := ret[0].(error)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeClient indicates an expected call of ServeClient.
|
|
||||||
func (mr *MockCoordinatorMockRecorder) ServeClient(arg0, arg1, arg2 any) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeClient", reflect.TypeOf((*MockCoordinator)(nil).ServeClient), arg0, arg1, arg2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTPDebug mocks base method.
|
// ServeHTTPDebug mocks base method.
|
||||||
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
|
func (m *MockCoordinator) ServeHTTPDebug(arg0 http.ResponseWriter, arg1 *http.Request) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -161,7 +161,7 @@ func RunDERPOnlyWebSockets(t *testing.T) *tailcfg.DERPMap {
|
|||||||
|
|
||||||
type TestMultiAgent struct {
|
type TestMultiAgent struct {
|
||||||
t testing.TB
|
t testing.TB
|
||||||
id uuid.UUID
|
ID uuid.UUID
|
||||||
a tailnet.MultiAgentConn
|
a tailnet.MultiAgentConn
|
||||||
nodeKey []byte
|
nodeKey []byte
|
||||||
discoKey string
|
discoKey string
|
||||||
@ -172,8 +172,8 @@ func NewTestMultiAgent(t testing.TB, coord tailnet.Coordinator) *TestMultiAgent
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dk, err := key.NewDisco().Public().MarshalText()
|
dk, err := key.NewDisco().Public().MarshalText()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
m := &TestMultiAgent{t: t, id: uuid.New(), nodeKey: nk, discoKey: string(dk)}
|
m := &TestMultiAgent{t: t, ID: uuid.New(), nodeKey: nk, discoKey: string(dk)}
|
||||||
m.a = coord.ServeMultiAgent(m.id)
|
m.a = coord.ServeMultiAgent(m.ID)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,7 +278,6 @@ func (m *TestMultiAgent) RequireEventuallyClosed(ctx context.Context) {
|
|||||||
|
|
||||||
type FakeCoordinator struct {
|
type FakeCoordinator struct {
|
||||||
CoordinateCalls chan *FakeCoordinate
|
CoordinateCalls chan *FakeCoordinate
|
||||||
ServeClientCalls chan *FakeServeClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) {
|
func (*FakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) {
|
||||||
@ -289,21 +288,6 @@ func (*FakeCoordinator) Node(uuid.UUID) *tailnet.Node {
|
|||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *FakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
|
||||||
errCh := make(chan error)
|
|
||||||
f.ServeClientCalls <- &FakeServeClient{
|
|
||||||
Conn: conn,
|
|
||||||
ID: id,
|
|
||||||
Agent: agent,
|
|
||||||
ErrCh: errCh,
|
|
||||||
}
|
|
||||||
return <-errCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*FakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*FakeCoordinator) Close() error {
|
func (*FakeCoordinator) Close() error {
|
||||||
panic("unimplemented")
|
panic("unimplemented")
|
||||||
}
|
}
|
||||||
@ -329,7 +313,6 @@ func (f *FakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name str
|
|||||||
func NewFakeCoordinator() *FakeCoordinator {
|
func NewFakeCoordinator() *FakeCoordinator {
|
||||||
return &FakeCoordinator{
|
return &FakeCoordinator{
|
||||||
CoordinateCalls: make(chan *FakeCoordinate, 100),
|
CoordinateCalls: make(chan *FakeCoordinate, 100),
|
||||||
ServeClientCalls: make(chan *FakeServeClient, 100),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,10 +3,12 @@ package test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
@ -18,43 +20,73 @@ type PeerStatus struct {
|
|||||||
readyForHandshake bool
|
readyForHandshake bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PeerOption func(*Peer)
|
||||||
|
|
||||||
|
func WithID(id uuid.UUID) PeerOption {
|
||||||
|
return func(p *Peer) {
|
||||||
|
p.ID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAuth(auth tailnet.CoordinateeAuth) PeerOption {
|
||||||
|
return func(p *Peer) {
|
||||||
|
p.auth = auth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
t testing.TB
|
t testing.TB
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
|
auth tailnet.CoordinateeAuth
|
||||||
name string
|
name string
|
||||||
|
nodeKey key.NodePublic
|
||||||
|
discoKey key.DiscoPublic
|
||||||
resps <-chan *proto.CoordinateResponse
|
resps <-chan *proto.CoordinateResponse
|
||||||
reqs chan<- *proto.CoordinateRequest
|
reqs chan<- *proto.CoordinateRequest
|
||||||
peers map[uuid.UUID]PeerStatus
|
peers map[uuid.UUID]PeerStatus
|
||||||
peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
|
peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, id ...uuid.UUID) *Peer {
|
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, opts ...PeerOption) *Peer {
|
||||||
p := &Peer{
|
p := &Peer{
|
||||||
t: t,
|
t: t,
|
||||||
name: name,
|
name: name,
|
||||||
peers: make(map[uuid.UUID]PeerStatus),
|
peers: make(map[uuid.UUID]PeerStatus),
|
||||||
peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
|
peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
|
||||||
|
ID: uuid.New(),
|
||||||
|
// SingleTailnetCoordinateeAuth allows connections to arbitrary peers
|
||||||
|
auth: tailnet.SingleTailnetCoordinateeAuth{},
|
||||||
|
// required for converting to and from protobuf, so we always include them
|
||||||
|
nodeKey: key.NewNode().Public(),
|
||||||
|
discoKey: key.NewDisco().Public(),
|
||||||
}
|
}
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
if len(id) > 1 {
|
for _, opt := range opts {
|
||||||
t.Fatal("too many")
|
opt(p)
|
||||||
}
|
}
|
||||||
if len(id) == 1 {
|
|
||||||
p.ID = id[0]
|
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, p.auth)
|
||||||
} else {
|
return p
|
||||||
p.ID = uuid.New()
|
|
||||||
}
|
}
|
||||||
// SingleTailnetTunnelAuth allows connections to arbitrary peers
|
|
||||||
p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, tailnet.SingleTailnetCoordinateeAuth{})
|
// NewAgent is a wrapper around NewPeer, creating a peer with Agent auth tied to its ID
|
||||||
|
func NewAgent(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string) *Peer {
|
||||||
|
id := uuid.New()
|
||||||
|
return NewPeer(ctx, t, coord, name, WithID(id), WithAuth(tailnet.AgentCoordinateeAuth{ID: id}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient is a wrapper around NewPeer, creating a peer with Client auth tied to the provided agentID
|
||||||
|
func NewClient(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, agentID uuid.UUID) *Peer {
|
||||||
|
p := NewPeer(ctx, t, coord, name, WithAuth(tailnet.ClientCoordinateeAuth{AgentID: agentID}))
|
||||||
|
p.AddTunnel(agentID)
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
|
func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
|
||||||
p.t.Helper()
|
p.t.Helper()
|
||||||
|
p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, p.auth)
|
||||||
p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, tailnet.SingleTailnetCoordinateeAuth{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) AddTunnel(other uuid.UUID) {
|
func (p *Peer) AddTunnel(other uuid.UUID) {
|
||||||
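
NewPeer's variadic ID argument becomes a functional-options list above (WithID, WithAuth), which is what keeps NewAgent and NewClient as thin wrappers. Here is a stand-alone miniature of that pattern; the types below are illustrative, and the real Peer lives in tailnet/test.

package optionsketch

import "github.com/google/uuid"

type peer struct {
	id   uuid.UUID
	name string
}

// peerOption mutates a peer under construction, mirroring test.PeerOption.
type peerOption func(*peer)

func withID(id uuid.UUID) peerOption {
	return func(p *peer) { p.id = id }
}

// newPeer applies defaults first, then lets each option override them.
func newPeer(name string, opts ...peerOption) *peer {
	p := &peer{id: uuid.New(), name: name}
	for _, opt := range opts {
		opt(p)
	}
	return p
}

test.NewAgent and test.NewClient in the hunk above are exactly this shape: NewPeer plus a fixed option set.
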
@ -71,7 +103,19 @@ func (p *Peer) AddTunnel(other uuid.UUID) {
|
|||||||
|
|
||||||
func (p *Peer) UpdateDERP(derp int32) {
|
func (p *Peer) UpdateDERP(derp int32) {
|
||||||
p.t.Helper()
|
p.t.Helper()
|
||||||
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: derp}}}
|
node := &proto.Node{PreferredDerp: derp}
|
||||||
|
p.UpdateNode(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) UpdateNode(node *proto.Node) {
|
||||||
|
p.t.Helper()
|
||||||
|
nk, err := p.nodeKey.MarshalBinary()
|
||||||
|
assert.NoError(p.t, err)
|
||||||
|
node.Key = nk
|
||||||
|
dk, err := p.discoKey.MarshalText()
|
||||||
|
assert.NoError(p.t, err)
|
||||||
|
node.Disco = string(dk)
|
||||||
|
req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}}
|
||||||
select {
|
select {
|
||||||
case <-p.ctx.Done():
|
case <-p.ctx.Done():
|
||||||
p.t.Errorf("timeout updating node for %s", p.name)
|
p.t.Errorf("timeout updating node for %s", p.name)
|
||||||
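
UpdateNode above now stamps every outgoing node with the peer's node and disco keys before converting to protobuf. A small stand-alone sketch of that marshalling step using the same tailscale key types follows; the function name is illustrative.

package keysketch

import (
	"tailscale.com/types/key"

	"github.com/coder/coder/v2/tailnet/proto"
)

// stampKeys fills in the wire representation of a peer's keys on a proto node,
// mirroring what Peer.UpdateNode does in the test helper above.
func stampKeys(node *proto.Node, nodeKey key.NodePublic, discoKey key.DiscoPublic) error {
	nk, err := nodeKey.MarshalBinary()
	if err != nil {
		return err
	}
	node.Key = nk
	dk, err := discoKey.MarshalText()
	if err != nil {
		return err
	}
	node.Disco = string(dk)
	return nil
}
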
@@ -115,13 +159,37 @@ func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
 		if ok && o.preferredDERP == derp {
 			return
 		}
-		if err := p.handleOneResp(); err != nil {
+		if err := p.readOneResp(); err != nil {
 			assert.NoError(p.t, err)
 			return
 		}
 	}
 }
 
+func (p *Peer) AssertNeverHasDERPs(ctx context.Context, other uuid.UUID, expected ...int32) {
+	p.t.Helper()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case resp, ok := <-p.resps:
+			if !ok {
+				p.t.Errorf("response channel closed")
+			}
+			if !assert.NoError(p.t, p.handleResp(resp)) {
+				return
+			}
+			derp, ok := p.peers[other]
+			if !ok {
+				continue
+			}
+			if !assert.NotContains(p.t, expected, derp) {
+				return
+			}
+		}
+	}
+}
+
 func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
 	p.t.Helper()
 	for {
@@ -129,7 +197,7 @@ func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
 		if !ok {
 			return
 		}
-		if err := p.handleOneResp(); err != nil {
+		if err := p.readOneResp(); err != nil {
 			assert.NoError(p.t, err)
 			return
 		}
@@ -143,7 +211,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
 		if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
 			return
 		}
-		if err := p.handleOneResp(); err != nil {
+		if err := p.readOneResp(); err != nil {
 			assert.NoError(p.t, err)
 			return
 		}
@@ -153,7 +221,7 @@ func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
 func (p *Peer) AssertEventuallyResponsesClosed() {
 	p.t.Helper()
 	for {
-		err := p.handleOneResp()
+		err := p.readOneResp()
 		if xerrors.Is(err, responsesClosed) {
 			return
 		}
@@ -163,6 +231,32 @@ func (p *Peer) AssertEventuallyResponsesClosed() {
 	}
 }
 
+func (p *Peer) AssertNotClosed(d time.Duration) {
+	p.t.Helper()
+	// nolint: gocritic // erroneously thinks we're hardcoding non testutil constants here
+	ctx, cancel := context.WithTimeout(context.Background(), d)
+	defer cancel()
+	for {
+		select {
+		case <-ctx.Done():
+			// success!
+			return
+		case <-p.ctx.Done():
+			p.t.Error("main ctx timeout before elapsed time")
+			return
+		case resp, ok := <-p.resps:
+			if !ok {
+				p.t.Error("response channel closed")
+				return
+			}
+			err := p.handleResp(resp)
+			if !assert.NoError(p.t, err) {
+				return
+			}
+		}
+	}
+}
+
 func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
 	p.t.Helper()
 	for {
@@ -171,7 +265,7 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
 			return
 		}
 
-		err := p.handleOneResp()
+		err := p.readOneResp()
 		if xerrors.Is(err, responsesClosed) {
 			return
 		}
@@ -181,8 +275,9 @@ func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
 func (p *Peer) AssertEventuallyGetsError(match string) {
 	p.t.Helper()
 	for {
-		err := p.handleOneResp()
+		err := p.readOneResp()
 		if xerrors.Is(err, responsesClosed) {
+			p.t.Error("closed before target error")
 			return
 		}
 
@@ -207,7 +302,7 @@ func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateRespon
 
 var responsesClosed = xerrors.New("responses closed")
 
-func (p *Peer) handleOneResp() error {
+func (p *Peer) readOneResp() error {
 	select {
 	case <-p.ctx.Done():
 		return p.ctx.Err()
@@ -215,6 +310,15 @@ func (p *Peer) handleOneResp() error {
 		if !ok {
 			return responsesClosed
 		}
+		err := p.handleResp(resp)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (p *Peer) handleResp(resp *proto.CoordinateResponse) error {
 	if resp.Error != "" {
 		return xerrors.New(resp.Error)
 	}
@@ -241,7 +345,6 @@ func (p *Peer) handleOneResp() error {
 			return xerrors.Errorf("unhandled update kind %s", update.Kind)
 		}
 	}
-	}
 	return nil
 }
 
@@ -261,3 +364,9 @@ func (p *Peer) Close(ctx context.Context) {
 		}
 	}
 }
+
+func (p *Peer) UngracefulDisconnect(ctx context.Context) {
+	p.t.Helper()
+	close(p.reqs)
+	p.Close(ctx)
+}
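UngracefulDisconnect closes the peer's request channel without a graceful teardown, which lets tests model a connection that simply vanishes. Continuing the hypothetical sketch above, one plausible pairing is with the existing AssertEventuallyLost helper:

	// Simulate an abrupt drop rather than a clean close, then check that the
	// other side eventually sees the peer marked LOST rather than DISCONNECTED.
	client.UngracefulDisconnect(ctx)
	agent.AssertEventuallyLost(client.ID)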
@@ -1,182 +0,0 @@
-package tailnet
-
-import (
-	"bytes"
-	"context"
-	"encoding/json"
-	"net"
-	"sync/atomic"
-	"time"
-
-	"github.com/google/uuid"
-
-	"cdr.dev/slog"
-	"github.com/coder/coder/v2/tailnet/proto"
-)
-
-const (
-	// WriteTimeout is the amount of time we wait to write a node update to a connection before we
-	// declare it hung. It is exported so that tests can use it.
-	WriteTimeout = time.Second * 5
-	// ResponseBufferSize is the max number of responses to buffer per connection before we start
-	// dropping updates
-	ResponseBufferSize = 512
-	// RequestBufferSize is the max number of requests to buffer per connection
-	RequestBufferSize = 32
-)
-
-type TrackedConn struct {
-	ctx      context.Context
-	cancel   func()
-	kind     QueueKind
-	conn     net.Conn
-	updates  chan *proto.CoordinateResponse
-	logger   slog.Logger
-	lastData []byte
-
-	// ID is an ephemeral UUID used to uniquely identify the owner of the
-	// connection.
-	id uuid.UUID
-
-	name       string
-	start      int64
-	lastWrite  int64
-	overwrites int64
-}
-
-func NewTrackedConn(ctx context.Context, cancel func(),
-	conn net.Conn,
-	id uuid.UUID,
-	logger slog.Logger,
-	name string,
-	overwrites int64,
-	kind QueueKind,
-) *TrackedConn {
-	// buffer updates so they don't block, since we hold the
-	// coordinator mutex while queuing. Node updates don't
-	// come quickly, so 512 should be plenty for all but
-	// the most pathological cases.
-	updates := make(chan *proto.CoordinateResponse, ResponseBufferSize)
-	now := time.Now().Unix()
-	return &TrackedConn{
-		ctx:        ctx,
-		conn:       conn,
-		cancel:     cancel,
-		updates:    updates,
-		logger:     logger,
-		id:         id,
-		start:      now,
-		lastWrite:  now,
-		name:       name,
-		overwrites: overwrites,
-		kind:       kind,
-	}
-}
-
-func (t *TrackedConn) Enqueue(resp *proto.CoordinateResponse) (err error) {
-	atomic.StoreInt64(&t.lastWrite, time.Now().Unix())
-	select {
-	case t.updates <- resp:
-		return nil
-	default:
-		return ErrWouldBlock
-	}
-}
-
-func (t *TrackedConn) UniqueID() uuid.UUID {
-	return t.id
-}
-
-func (t *TrackedConn) Kind() QueueKind {
-	return t.kind
-}
-
-func (t *TrackedConn) Name() string {
-	return t.name
-}
-
-func (t *TrackedConn) Stats() (start, lastWrite int64) {
-	return t.start, atomic.LoadInt64(&t.lastWrite)
-}
-
-func (t *TrackedConn) Overwrites() int64 {
-	return t.overwrites
-}
-
-func (t *TrackedConn) CoordinatorClose() error {
-	return t.Close()
-}
-
-func (t *TrackedConn) Done() <-chan struct{} {
-	return t.ctx.Done()
-}
-
-// Close the connection and cancel the context for reading node updates from the queue
-func (t *TrackedConn) Close() error {
-	t.cancel()
-	return t.conn.Close()
-}
-
-// SendUpdates reads node updates and writes them to the connection. Ends when writes hit an error or context is
-// canceled.
-func (t *TrackedConn) SendUpdates() {
-	for {
-		select {
-		case <-t.ctx.Done():
-			t.logger.Debug(t.ctx, "done sending updates")
-			return
-		case resp := <-t.updates:
-			nodes, err := OnlyNodeUpdates(resp)
-			if err != nil {
-				t.logger.Critical(t.ctx, "unable to parse response", slog.Error(err))
-				return
-			}
-			if len(nodes) == 0 {
-				t.logger.Debug(t.ctx, "skipping response with no nodes")
-				continue
-			}
-			data, err := json.Marshal(nodes)
-			if err != nil {
-				t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
-				return
-			}
-			if bytes.Equal(t.lastData, data) {
-				t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", string(data)))
-				continue
-			}
-
-			// Set a deadline so that hung connections don't put back pressure on the system.
-			// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
-			err = t.conn.SetWriteDeadline(time.Now().Add(WriteTimeout))
-			if err != nil {
-				// often, this is just because the connection is closed/broken, so only log at debug.
-				t.logger.Debug(t.ctx, "unable to set write deadline", slog.Error(err))
-				_ = t.Close()
-				return
-			}
-			_, err = t.conn.Write(data)
-			if err != nil {
-				// often, this is just because the connection is closed/broken, so only log at debug.
-				t.logger.Debug(t.ctx, "could not write nodes to connection",
-					slog.Error(err), slog.F("nodes", string(data)))
-				_ = t.Close()
-				return
-			}
-			t.logger.Debug(t.ctx, "wrote nodes", slog.F("nodes", string(data)))
-
-			// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
-			// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
-			// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
-			// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
-			// our successful write, it is important that we reset the deadline before it fires.
-			err = t.conn.SetWriteDeadline(time.Time{})
-			if err != nil {
-				// often, this is just because the connection is closed/broken, so only log at debug.
-				t.logger.Debug(t.ctx, "unable to extend write deadline", slog.Error(err))
-				_ = t.Close()
-				return
-			}
-			t.lastData = data
-		}
-	}
-}
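Although it is being deleted, the SendUpdates loop above still documents a useful write-deadline discipline: set a deadline before each write so a hung connection cannot apply back pressure, then clear it after a successful write so the expired timer cannot fail a later, healthy write. Below is a generic, hedged distillation of that pattern; it is plain Go against the net.Conn interface, not coder code.

package deadlinewrite

import (
	"encoding/json"
	"net"
	"time"
)

// writeWithDeadline marshals v as JSON and writes it to conn, bounding the write
// with a deadline and clearing the deadline again once the write succeeds.
func writeWithDeadline(conn net.Conn, v any, timeout time.Duration) error {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}
	// Bound the write so a hung peer cannot stall the sender indefinitely.
	if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return err
	}
	if _, err := conn.Write(data); err != nil {
		return err
	}
	// Clear the deadline so it cannot fire later and break an unrelated write.
	return conn.SetWriteDeadline(time.Time{})
}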