Files
coder/vpn/tunnel_internal_test.go
Ethan edd982e852 fix(vpn/tunnel): fix panic when starting tunnel with headers (#16565)
`http.Header` is a map so it needs to be initialized ..
2025-02-14 03:11:51 +00:00

512 lines
14 KiB
Go

package vpn
import (
"context"
"net"
"net/netip"
"net/url"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/util/dnsname"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// newFakeClient returns a fakeClient bound to ctx. Connections handed out by
// its NewConn method are taken from the ch channel, which has a buffer of
// one so a single connection can be queued without a receiver waiting.
func newFakeClient(ctx context.Context, t *testing.T) *fakeClient {
	conns := make(chan *fakeConn, 1)
	return &fakeClient{ch: conns, ctx: ctx, t: t}
}
// fakeClient is a test double for Client whose connections are supplied by
// the test through a channel instead of being dialed.
type fakeClient struct {
// t is the *testing.T handed to newFakeClient.
t *testing.T
// ctx bounds how long NewConn waits for a connection to be supplied.
ctx context.Context
// ch carries the connections that NewConn returns.
ch chan *fakeConn
}
// Compile-time check that *fakeClient implements Client.
var _ Client = (*fakeClient)(nil)
// NewConn implements Client. Every argument is ignored: the call blocks until
// a connection arrives on f.ch and returns it, or returns the context's error
// once f.ctx is done.
func (f *fakeClient) NewConn(_ context.Context, _ *url.URL, _ string, _ *Options) (Conn, error) {
	select {
	case queued := <-f.ch:
		return queued, nil
	case <-f.ctx.Done():
	}
	return nil, f.ctx.Err()
}
// newFakeConn returns a fakeConn that holds the given workspace state and
// handshake time, together with a fresh, still-open closed channel.
func newFakeConn(state tailnet.WorkspaceUpdate, hsTime time.Time) *fakeConn {
	fc := new(fakeConn)
	fc.state = state
	fc.hsTime = hsTime
	fc.closed = make(chan struct{})
	return fc
}
// fakeConn is a test double for Conn that serves canned data and signals when
// it has been closed.
type fakeConn struct {
// state is returned as-is by CurrentWorkspaceState.
state tailnet.WorkspaceUpdate
// hsTime is reported as LastWireguardHandshake by GetPeerDiagnostics.
hsTime time.Time
// closed is closed by Close so tests can wait for the connection to be
// shut down.
closed chan struct{}
// doClose guards the close of the closed channel so it happens only once.
doClose sync.Once
}
// Compile-time check that *fakeConn implements Conn.
var _ Conn = (*fakeConn)(nil)
// CurrentWorkspaceState implements Conn by returning the workspace state held
// in f.state. The returned error is always nil.
func (f *fakeConn) CurrentWorkspaceState() (tailnet.WorkspaceUpdate, error) {
	current := f.state
	return current, nil
}
// GetPeerDiagnostics implements Conn. The peer ID is ignored; the only field
// populated in the result is LastWireguardHandshake, taken from f.hsTime.
func (f *fakeConn) GetPeerDiagnostics(_ uuid.UUID) tailnet.PeerDiagnostics {
	var diag tailnet.PeerDiagnostics
	diag.LastWireguardHandshake = f.hsTime
	return diag
}
// Close implements Conn. The first call closes f.closed; later calls do
// nothing. It always returns nil.
func (f *fakeConn) Close() error {
	signalClosed := func() { close(f.closed) }
	f.doClose.Do(signalClosed)
	return nil
}
// TestTunnel_StartStop checks that a Start RPC carrying custom headers makes
// the tunnel request a connection from the Client, and that a subsequent Stop
// RPC closes that connection.
func TestTunnel_StartStop(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	client := newFakeClient(ctx, t)
	conn := newFakeConn(tailnet.WorkspaceUpdate{}, time.Time{})
	_, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t))

	var reply *TunnelMessage
	rpcErrs := make(chan error, 1)
	// callAsync issues a unary RPC in the background. The reply is stored
	// before the error is sent, so reading reply after receiving from rpcErrs
	// is properly ordered.
	callAsync := func(msg *ManagerMessage) {
		go func() {
			var rpcErr error
			reply, rpcErr = mgr.unaryRPC(ctx, msg)
			rpcErrs <- rpcErr
		}()
	}

	// When: we start the tunnel
	callAsync(&ManagerMessage{
		Msg: &ManagerMessage_Start{
			Start: &StartRequest{
				TunnelFileDescriptor: 2,
				CoderUrl:             "https://coder.example.com",
				ApiToken:             "fakeToken",
				Headers: []*StartRequest_Header{
					{Name: "X-Test-Header", Value: "test"},
				},
			},
		},
	})

	// Then: `NewConn` is called,
	testutil.RequireSendCtx(ctx, t, client.ch, conn)
	// And: a response is received
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, rpcErrs))
	_, isStart := reply.Msg.(*TunnelMessage_Start)
	require.True(t, isStart)

	// When: we stop the tunnel
	callAsync(&ManagerMessage{
		Msg: &ManagerMessage_Stop{},
	})

	// Then: `Close` is called on the connection
	testutil.RequireRecvCtx(ctx, t, conn.closed)
	// And: a Stop response is received
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, rpcErrs))
	_, isStop := reply.Msg.(*TunnelMessage_Stop)
	require.True(t, isStop)

	require.NoError(t, mgr.Close())
}
// TestTunnel_PeerUpdate checks both directions of peer updates: an update
// pushed into the tunnel is forwarded to the manager as a non-RPC PeerUpdate
// message, and a GetPeerUpdate RPC from the manager is answered with the
// state reported by the Conn.
func TestTunnel_PeerUpdate(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)

	wsID1 := uuid.UUID{1}
	wsID2 := uuid.UUID{2}

	client := newFakeClient(ctx, t)
	// The Conn reports two workspaces as its current state.
	connState := tailnet.WorkspaceUpdate{
		UpsertedWorkspaces: []*tailnet.Workspace{
			{ID: wsID1},
			{ID: wsID2},
		},
	}
	conn := newFakeConn(connState, time.Time{})

	tun, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t))

	var reply *TunnelMessage
	rpcErrs := make(chan error, 1)
	// sendRPC issues a unary RPC in the background; reply is written before
	// the error is published on rpcErrs.
	sendRPC := func(msg *ManagerMessage) {
		go func() {
			var rpcErr error
			reply, rpcErr = mgr.unaryRPC(ctx, msg)
			rpcErrs <- rpcErr
		}()
	}

	// Given: a started tunnel
	sendRPC(&ManagerMessage{
		Msg: &ManagerMessage_Start{
			Start: &StartRequest{
				TunnelFileDescriptor: 2,
				CoderUrl:             "https://coder.example.com",
				ApiToken:             "fakeToken",
			},
		},
	})
	testutil.RequireSendCtx(ctx, t, client.ch, conn)
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, rpcErrs))
	_, isStart := reply.Msg.(*TunnelMessage_Start)
	require.True(t, isStart)

	// When: we inform the tunnel of a WorkspaceUpdate
	pushErr := tun.Update(tailnet.WorkspaceUpdate{
		UpsertedWorkspaces: []*tailnet.Workspace{
			{ID: wsID2},
		},
	})
	require.NoError(t, pushErr)

	// Then: the tunnel sends a PeerUpdate message
	req := testutil.RequireRecvCtx(ctx, t, mgr.requests)
	require.Nil(t, req.msg.Rpc)
	pushed := req.msg.GetPeerUpdate()
	require.NotNil(t, pushed)
	require.Len(t, pushed.UpsertedWorkspaces, 1)
	require.Equal(t, wsID2[:], pushed.UpsertedWorkspaces[0].Id)

	// When: the manager requests a PeerUpdate
	sendRPC(&ManagerMessage{
		Msg: &ManagerMessage_GetPeerUpdate{},
	})

	// Then: a PeerUpdate message is sent using the Conn's state
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, rpcErrs))
	_, isPeerUpdate := reply.Msg.(*TunnelMessage_PeerUpdate)
	require.True(t, isPeerUpdate)
	snapshot := reply.GetPeerUpdate()
	require.Len(t, snapshot.UpsertedWorkspaces, 2)
	require.Equal(t, wsID1[:], snapshot.UpsertedWorkspaces[0].Id)
	require.Equal(t, wsID2[:], snapshot.UpsertedWorkspaces[1].Id)
}
// TestTunnel_NetworkSettings checks that ApplyNetworkSettings sends a
// NetworkSettings RPC to the manager and returns without error once the
// manager replies with a successful response.
func TestTunnel_NetworkSettings(t *testing.T) {
	t.Parallel()

	// wantMTU is the MTU the tunnel asks the manager to apply.
	const wantMTU = 1200

	ctx := testutil.Context(t, testutil.WaitShort)
	client := newFakeClient(ctx, t)
	conn := newFakeConn(tailnet.WorkspaceUpdate{}, time.Time{})
	tun, mgr := setupTunnel(t, ctx, client, quartz.NewMock(t))

	// results is shared by the Start RPC and the ApplyNetworkSettings call;
	// the two are never in flight at the same time.
	results := make(chan error, 1)
	var startReply *TunnelMessage

	// Given: a started tunnel
	go func() {
		var rpcErr error
		startReply, rpcErr = mgr.unaryRPC(ctx, &ManagerMessage{
			Msg: &ManagerMessage_Start{
				Start: &StartRequest{
					TunnelFileDescriptor: 2,
					CoderUrl:             "https://coder.example.com",
					ApiToken:             "fakeToken",
				},
			},
		})
		results <- rpcErr
	}()
	testutil.RequireSendCtx(ctx, t, client.ch, conn)
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, results))
	_, isStart := startReply.Msg.(*TunnelMessage_Start)
	require.True(t, isStart)

	// When: we inform the tunnel of network settings
	go func() {
		results <- tun.ApplyNetworkSettings(ctx, &NetworkSettingsRequest{
			Mtu: wantMTU,
		})
	}()

	// Then: the tunnel sends a NetworkSettings message
	req := testutil.RequireRecvCtx(ctx, t, mgr.requests)
	require.NotNil(t, req.msg.Rpc)
	require.Equal(t, uint32(wantMTU), req.msg.GetNetworkSettings().Mtu)

	// The manager acknowledges the request by replying to its message ID.
	ack := &ManagerMessage{
		Rpc: &RPC{ResponseTo: req.msg.Rpc.MsgId},
		Msg: &ManagerMessage_NetworkSettings{
			NetworkSettings: &NetworkSettingsResponse{
				Success: true,
			},
		},
	}
	// NOTE(review): if RequireSendCtx fails the test via t.FailNow, calling
	// it from a goroutine other than the test's is not supported by package
	// testing — confirm against the helper's implementation.
	go func() {
		testutil.RequireSendCtx(ctx, t, mgr.sendCh, ack)
	}()

	// And: `ApplyNetworkSettings` returns without error once the manager responds
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, results))
}
// TestUpdater_createPeerUpdate checks the conversion of a
// tailnet.WorkspaceUpdate into a PeerUpdate message: workspaces and agents are
// copied over, each agent's host names and addresses are flattened into Fqdn
// and IpAddrs, and only upserted agents carry a LastHandshake (taken from the
// Conn's peer diagnostics).
func TestUpdater_createPeerUpdate(t *testing.T) {
	t.Parallel()

	w1ID := uuid.UUID{1}
	w2ID := uuid.UUID{2}
	w1a1ID := uuid.UUID{4}
	w2a1ID := uuid.UUID{5}
	w1a1IP := netip.MustParseAddr("fd60:627a:a42b:0101::")
	w2a1IP := netip.MustParseAddr("fd60:627a:a42b:0301::")

	ctx := testutil.Context(t, testutil.WaitShort)

	hsTime := time.Now().Add(-time.Minute).UTC()
	// Named u rather than updater so the local does not shadow the type.
	u := updater{
		ctx:         ctx,
		netLoopDone: make(chan struct{}),
		agents:      map[uuid.UUID]tailnet.Agent{},
		conn:        newFakeConn(tailnet.WorkspaceUpdate{}, hsTime),
	}

	update := u.createPeerUpdateLocked(tailnet.WorkspaceUpdate{
		UpsertedWorkspaces: []*tailnet.Workspace{
			{ID: w1ID, Name: "w1", Status: proto.Workspace_STARTING},
		},
		UpsertedAgents: []*tailnet.Agent{
			{
				ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID,
				Hosts: map[dnsname.FQDN][]netip.Addr{
					"w1.coder.":            {w1a1IP},
					"w1a1.w1.me.coder.":    {w1a1IP},
					"w1a1.w1.testy.coder.": {w1a1IP},
				},
			},
		},
		DeletedWorkspaces: []*tailnet.Workspace{
			{ID: w2ID, Name: "w2", Status: proto.Workspace_STOPPED},
		},
		DeletedAgents: []*tailnet.Agent{
			{
				ID: w2a1ID, Name: "w2a1", WorkspaceID: w2ID,
				Hosts: map[dnsname.FQDN][]netip.Addr{
					"w2.coder.":            {w2a1IP},
					"w2a1.w2.me.coder.":    {w2a1IP},
					"w2a1.w2.testy.coder.": {w2a1IP},
				},
			},
		},
	})

	// Assert both lengths before indexing so that a missing agent is reported
	// as a test failure rather than an index-out-of-range panic.
	require.Len(t, update.UpsertedAgents, 1)
	require.Len(t, update.DeletedAgents, 1)
	// Hosts is a map, so the order of the generated FQDNs is not fixed; sort
	// them before comparing.
	slices.Sort(update.UpsertedAgents[0].Fqdn)
	slices.Sort(update.DeletedAgents[0].Fqdn)

	// require.Equal takes (expected, actual).
	require.Equal(t, &PeerUpdate{
		UpsertedWorkspaces: []*Workspace{
			{Id: w1ID[:], Name: "w1", Status: Workspace_Status(proto.Workspace_STARTING)},
		},
		UpsertedAgents: []*Agent{
			{
				Id: w1a1ID[:], Name: "w1a1", WorkspaceId: w1ID[:],
				Fqdn:          []string{"w1.coder.", "w1a1.w1.me.coder.", "w1a1.w1.testy.coder."},
				IpAddrs:       []string{w1a1IP.String()},
				LastHandshake: timestamppb.New(hsTime),
			},
		},
		DeletedWorkspaces: []*Workspace{
			{Id: w2ID[:], Name: "w2", Status: Workspace_Status(proto.Workspace_STOPPED)},
		},
		DeletedAgents: []*Agent{
			{
				Id: w2a1ID[:], Name: "w2a1", WorkspaceId: w2ID[:],
				Fqdn:          []string{"w2.coder.", "w2a1.w2.me.coder.", "w2a1.w2.testy.coder."},
				IpAddrs:       []string{w2a1IP.String()},
				LastHandshake: nil,
			},
		},
	}, update)
}
// TestTunnel_sendAgentUpdate drives the tunnel with a mock clock and checks
// the PeerUpdate messages it emits each time the clock is advanced: they list
// every currently known agent with its last handshake time, and follow agent
// upserts and deletions.
func TestTunnel_sendAgentUpdate(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	mClock := quartz.NewMock(t)

	wID1 := uuid.UUID{1}
	aID1 := uuid.UUID{2}
	aID2 := uuid.UUID{3}
	// Every agent in this test is given the same address.
	agentIP := netip.MustParseAddr("fd60:627a:a42b:0101::")
	hsTime := time.Now().Add(-time.Minute).UTC()

	client := newFakeClient(ctx, t)
	conn := newFakeConn(tailnet.WorkspaceUpdate{}, hsTime)
	tun, mgr := setupTunnel(t, ctx, client, mClock)

	// Start the tunnel.
	startErrs := make(chan error, 1)
	var startReply *TunnelMessage
	go func() {
		var rpcErr error
		startReply, rpcErr = mgr.unaryRPC(ctx, &ManagerMessage{
			Msg: &ManagerMessage_Start{
				Start: &StartRequest{
					TunnelFileDescriptor: 2,
					CoderUrl:             "https://coder.example.com",
					ApiToken:             "fakeToken",
				},
			},
		})
		startErrs <- rpcErr
	}()
	testutil.RequireSendCtx(ctx, t, client.ch, conn)
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, startErrs))
	_, isStart := startReply.Msg.(*TunnelMessage_Start)
	require.True(t, isStart)

	// Inform the tunnel of the initial state
	updateErr := tun.Update(tailnet.WorkspaceUpdate{
		UpsertedWorkspaces: []*tailnet.Workspace{
			{ID: wID1, Name: "w1", Status: proto.Workspace_STARTING},
		},
		UpsertedAgents: []*tailnet.Agent{
			{
				ID:          aID1,
				Name:        "agent1",
				WorkspaceID: wID1,
				Hosts: map[dnsname.FQDN][]netip.Addr{
					"agent1.coder.": {agentIP},
				},
			},
		},
	})
	require.NoError(t, updateErr)
	req := testutil.RequireRecvCtx(ctx, t, mgr.requests)
	require.Nil(t, req.msg.Rpc)
	pu := req.msg.GetPeerUpdate()
	require.NotNil(t, pu)
	require.Len(t, pu.UpsertedAgents, 1)
	require.Equal(t, aID1[:], pu.UpsertedAgents[0].Id)

	// `sendAgentUpdate` produces the same PeerUpdate message until an agent
	// update is received
	for i := 0; i < 2; i++ {
		mClock.AdvanceNext()
		// Then: the tunnel sends a PeerUpdate message of agent upserts,
		// with the last handshake set
		req = testutil.RequireRecvCtx(ctx, t, mgr.requests)
		require.Nil(t, req.msg.Rpc)
		pu = req.msg.GetPeerUpdate()
		require.NotNil(t, pu)
		require.Len(t, pu.UpsertedAgents, 1)
		require.Equal(t, aID1[:], pu.UpsertedAgents[0].Id)
		require.Equal(t, hsTime, pu.UpsertedAgents[0].LastHandshake.AsTime())
	}

	// Upsert a new agent
	updateErr = tun.Update(tailnet.WorkspaceUpdate{
		UpsertedWorkspaces: []*tailnet.Workspace{},
		UpsertedAgents: []*tailnet.Agent{
			{
				ID:          aID2,
				Name:        "agent2",
				WorkspaceID: wID1,
				Hosts: map[dnsname.FQDN][]netip.Addr{
					"agent2.coder.": {agentIP},
				},
			},
		},
	})
	require.NoError(t, updateErr)
	// Drain the message sent for the update itself; its content is not
	// inspected here.
	_ = testutil.RequireRecvCtx(ctx, t, mgr.requests)

	// The new update includes the new agent
	mClock.AdvanceNext()
	req = testutil.RequireRecvCtx(ctx, t, mgr.requests)
	require.Nil(t, req.msg.Rpc)
	pu = req.msg.GetPeerUpdate()
	require.NotNil(t, pu)
	require.Len(t, pu.UpsertedAgents, 2)
	// Order the agents by name so they can be addressed by index.
	slices.SortFunc(pu.UpsertedAgents, func(x, y *Agent) int {
		return strings.Compare(x.Name, y.Name)
	})
	first, second := pu.UpsertedAgents[0], pu.UpsertedAgents[1]
	require.Equal(t, aID1[:], first.Id)
	require.Equal(t, hsTime, first.LastHandshake.AsTime())
	require.Equal(t, aID2[:], second.Id)
	require.Equal(t, hsTime, second.LastHandshake.AsTime())

	// Delete an agent
	updateErr = tun.Update(tailnet.WorkspaceUpdate{
		DeletedAgents: []*tailnet.Agent{
			{
				ID:          aID1,
				Name:        "agent1",
				WorkspaceID: wID1,
				Hosts: map[dnsname.FQDN][]netip.Addr{
					"agent1.coder.": {agentIP},
				},
			},
		},
	})
	require.NoError(t, updateErr)
	// Drain the message sent for the update itself.
	_ = testutil.RequireRecvCtx(ctx, t, mgr.requests)

	// The new update doesn't include the deleted agent
	mClock.AdvanceNext()
	req = testutil.RequireRecvCtx(ctx, t, mgr.requests)
	require.Nil(t, req.msg.Rpc)
	pu = req.msg.GetPeerUpdate()
	require.NotNil(t, pu)
	require.Len(t, pu.UpsertedAgents, 1)
	require.Equal(t, aID2[:], pu.UpsertedAgents[0].Id)
	require.Equal(t, hsTime, pu.UpsertedAgents[0].LastHandshake.AsTime())
}
// setupTunnel connects a Tunnel and a manager-role speaker over an in-memory
// net.Pipe and returns both after the manager speaker has been started. Both
// pipe ends are closed when the test finishes. The test fails if either
// constructor returns an error or ctx expires first.
//
//nolint:revive // t takes precedence
func setupTunnel(t *testing.T, ctx context.Context, client *fakeClient, mClock quartz.Clock) (*Tunnel, *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]) {
	mgrSide, tunSide := net.Pipe()
	t.Cleanup(func() { _ = mgrSide.Close() })
	t.Cleanup(func() { _ = tunSide.Close() })
	logger := testutil.Logger(t)

	var (
		tun *Tunnel
		mgr *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]
	)
	// The two constructors run concurrently, one per pipe end.
	// NOTE(review): presumably required because net.Pipe is unbuffered and
	// each constructor exchanges data with its peer — confirm.
	// Each goroutine stores its result before reporting its error, so tun and
	// mgr are safe to read once both errors have been received.
	setupErrs := make(chan error, 2)
	go func() {
		var err error
		mgr, err = newSpeaker[*ManagerMessage, *TunnelMessage](ctx, logger.Named("manager"), mgrSide, SpeakerRoleManager, SpeakerRoleTunnel)
		setupErrs <- err
	}()
	go func() {
		var err error
		tun, err = NewTunnel(ctx, logger.Named("tunnel"), tunSide, client, WithClock(mClock))
		setupErrs <- err
	}()
	for i := 0; i < 2; i++ {
		require.NoError(t, testutil.RequireRecvCtx(ctx, t, setupErrs))
	}

	mgr.start()
	return tun, mgr
}