coder/tailnet/test/integration/integration.go
Spike Curtis af4a6682b4 fix: use tailscale that avoids small MTU paths (#18323)
Fixes #15523

Uses the latest https://github.com/coder/tailscale, which includes https://github.com/coder/tailscale/pull/85 to stop selecting paths with a small MTU for direct connections.

Also updates the tailnet integration test to reproduce the issue. The previous version had the two peers connected by a single veth, which lets the OS fragment the packet. In the new version, the two peers (and the server) are all connected through a central router, and the link between peer 1 and the router has an adjustable MTU. IPv6 does not allow packets to be fragmented by intermediate routers, so sending a too-large packet in this topology forces the router to drop it, reproducing the issue (without the tailscale changes).
2025-06-11 14:16:25 +04:00
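A rough sketch of the reduced-MTU link described above (assuming the harness shells out to iproute2; the namespace and interface names here are hypothetical, and the repo's actual networking helpers may differ):

	// Hypothetical: shrink the MTU on peer 1's side of its veth pair so the
	// router must drop (rather than fragment) oversized IPv6 packets.
	out, err := exec.Command("ip", "-netns", "peer1", "link", "set", "dev", "eth0", "mtu", "1280").CombinedOutput()
	require.NoError(t, err, string(out))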


//go:build linux
// +build linux

package integration

import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/packet"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/wgengine/capture"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw/loggermw"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
type ClientNumber int
const (
ClientNumber1 ClientNumber = 1
ClientNumber2 ClientNumber = 2
)
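// Client describes one of the two tailnet peers in a test: its fixed UUID and
// UDP listen port, whether it is the peer that runs the test assertions
// (ShouldRunTests), and whether it acts as the tunnel source (TunnelSrc) or as
// the destination/agent side.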
type Client struct {
Number ClientNumber
ID uuid.UUID
ListenPort uint16
ShouldRunTests bool
TunnelSrc bool
}
var Client1 = Client{
Number: ClientNumber1,
ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"),
ListenPort: client1Port,
ShouldRunTests: true,
TunnelSrc: true,
}
var Client2 = Client{
Number: ClientNumber2,
ID: uuid.MustParse("00000000-0000-0000-0000-000000000002"),
ListenPort: client2Port,
ShouldRunTests: false,
TunnelSrc: false,
}
type TestTopology struct {
Name string
NetworkingProvider NetworkingProvider
// Server is the server starter for the test. It is executed in the server
// subprocess.
Server ServerStarter
// ClientStarter.StartClient gets called in each client subprocess. It's expected to
// create the tailnet.Conn and ensure connectivity to its peer.
ClientStarter ClientStarter
// RunTests is the main test function. It's called in each of the client
// subprocesses. If tests can only run once, they should check the client ID
// and return early if it's not the expected one.
RunTests func(t *testing.T, logger slog.Logger, serverURL *url.URL, conn *tailnet.Conn, me Client, peer Client)
}
type ServerStarter interface {
// StartServer should start the server and return once it's listening; it
// should not block after that. Cleanup should be handled by t.Cleanup.
StartServer(t *testing.T, logger slog.Logger, listenAddr string)
}
type NetworkingProvider interface {
// SetupNetworking creates interfaces and network namespaces for the test.
// The simplest implementation is NetworkSetupDefault, which only creates a
// network namespace shared by all tests.
SetupNetworking(t *testing.T, logger slog.Logger) TestNetworking
}
type ClientStarter interface {
StartClient(t *testing.T, logger slog.Logger, serverURL *url.URL, derpMap *tailcfg.DERPMap, me Client, peer Client) *tailnet.Conn
}
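// A hypothetical topology wiring these pieces together might look like the
// following (field values are illustrative only; NetworkingProvider is omitted):
//
//	TestTopology{
//		Name:          "basic",
//		Server:        SimpleServerOptions{},
//		ClientStarter: BasicClientStarter{WaitForConnection: true},
//		RunTests: func(t *testing.T, _ slog.Logger, _ *url.URL, conn *tailnet.Conn, _, peer Client) {
//			ctx := testutil.Context(t, testutil.WaitShort)
//			_, _, _, err := conn.Ping(ctx, tailnet.TailscaleServicePrefix.AddrFromUUID(peer.ID))
//			require.NoError(t, err)
//		},
//	}
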
type SimpleServerOptions struct {
// FailUpgradeDERP will make the DERP server fail to handle the initial DERP
// upgrade in a way that causes the client to automatically fall back to
// DERP-over-WebSocket.
// Incompatible with DERPWebsocketOnly.
FailUpgradeDERP bool
// DERPWebsocketOnly will make the DERP server only accept WebSocket
// connections. If a DERP request is received that is not using WebSocket
// fallback, the test will fail.
// Incompatible with FailUpgradeDERP.
DERPWebsocketOnly bool
}
var _ ServerStarter = SimpleServerOptions{}
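// connManager tracks the active coordinator websocket connections so the
// /restart handler can close them all at once to simulate a coordinator
// restart.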
type connManager struct {
mu sync.Mutex
conns map[uuid.UUID]net.Conn
}
func (c *connManager) Add(id uuid.UUID, conn net.Conn) func() {
c.mu.Lock()
defer c.mu.Unlock()
c.conns[id] = conn
return func() {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.conns, id)
}
}
func (c *connManager) CloseAll() {
c.mu.Lock()
defer c.mu.Unlock()
for _, conn := range c.conns {
_ = conn.Close()
}
c.conns = make(map[uuid.UUID]net.Conn)
}
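// derpServer bundles a DERP server with its websocket-capable HTTP handler so
// the /restart handlers can swap in a fresh instance and close the old one.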
type derpServer struct {
http.Handler
srv *derp.Server
closeFn func()
}
func newDerpServer(t *testing.T, logger slog.Logger) *derpServer {
derpSrv := derp.NewServer(key.NewNode(), tailnet.Logger(logger.Named("derp")))
derpHandler, derpCloseFunc := tailnet.WithWebsocketSupport(derpSrv, derphttp.Handler(derpSrv))
t.Cleanup(derpCloseFunc)
return &derpServer{
srv: derpSrv,
Handler: derpHandler,
closeFn: derpCloseFunc,
}
}
func (s *derpServer) Close() {
s.srv.Close()
s.closeFn()
}
//nolint:revive
func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux {
coord := tailnet.NewCoordinator(logger)
var coordPtr atomic.Pointer[tailnet.Coordinator]
coordPtr.Store(&coord)
t.Cleanup(func() { _ = coord.Close() })
cm := connManager{
conns: make(map[uuid.UUID]net.Conn),
}
csvc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: 10 * time.Minute,
DERPMapFn: func() *tailcfg.DERPMap {
return &tailcfg.DERPMap{
// Clients will set their own based on their custom access URL.
Regions: map[int]*tailcfg.DERPRegion{},
}
},
NetworkTelemetryHandler: func(batch []*tailnetproto.TelemetryEvent) {},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
derpServer := atomic.Pointer[derpServer]{}
derpServer.Store(newDerpServer(t, logger))
t.Cleanup(func() {
derpServer.Load().Close()
})
r := chi.NewRouter()
r.Use(
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Debug(r.Context(), "start "+r.Method, slog.F("path", r.URL.Path), slog.F("remote_ip", r.RemoteAddr))
next.ServeHTTP(w, r)
})
},
tracing.StatusWriterMiddleware,
loggermw.Logger(logger),
)
r.Route("/derp", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
logger.Info(r.Context(), "start derp request", slog.F("path", r.URL.Path), slog.F("remote_ip", r.RemoteAddr))
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
if upgrade != "derp" && upgrade != "websocket" {
http.Error(w, "invalid DERP upgrade header", http.StatusBadRequest)
t.Errorf("invalid DERP upgrade header: %s", upgrade)
return
}
if o.FailUpgradeDERP && upgrade == "derp" {
// 4xx status codes will cause the client to fallback to
// DERP-over-WebSocket.
http.Error(w, "test derp upgrade failure", http.StatusBadRequest)
return
}
if o.DERPWebsocketOnly && upgrade != "websocket" {
logger.Error(r.Context(), "non-websocket DERP request received", slog.F("path", r.URL.Path), slog.F("remote_ip", r.RemoteAddr))
http.Error(w, "non-websocket DERP request received", http.StatusBadRequest)
t.Error("non-websocket DERP request received")
return
}
derpServer.Load().ServeHTTP(w, r)
})
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
r.Post("/restart", func(w http.ResponseWriter, r *http.Request) {
oldServer := derpServer.Swap(newDerpServer(t, logger))
oldServer.Close()
w.WriteHeader(http.StatusOK)
})
})
// /restart?derp=[true|false]&coordinator=[true|false]
r.Post("/restart", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("derp") == "true" {
logger.Info(r.Context(), "killing DERP server")
oldServer := derpServer.Swap(newDerpServer(t, logger))
oldServer.Close()
logger.Info(r.Context(), "restarted DERP server")
}
if r.URL.Query().Get("coordinator") == "true" {
logger.Info(r.Context(), "simulating coordinator restart")
cm.CloseAll()
}
w.WriteHeader(http.StatusOK)
})
r.Get("/api/v2/workspaceagents/{id}/coordinate", func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idStr := chi.URLParamFromCtx(ctx, "id")
id, err := uuid.Parse(idStr)
if err != nil {
logger.Warn(ctx, "bad agent ID passed in URL params", slog.F("id_str", idStr), slog.Error(err))
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
Message: "Bad agent id.",
Detail: err.Error(),
})
return
}
conn, err := websocket.Accept(w, r, nil)
if err != nil {
logger.Warn(ctx, "failed to accept websocket", slog.Error(err))
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
cleanFn := cm.Add(id, wsNetConn)
defer cleanFn()
err = csvc.ServeConnV2(ctx, wsNetConn, tailnet.StreamID{
Name: "client-" + id.String(),
ID: id,
Auth: tailnet.SingleTailnetCoordinateeAuth{},
})
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, context.Canceled) {
logger.Warn(ctx, "failed to serve conn", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
})
return r
}
func (o SimpleServerOptions) StartServer(t *testing.T, logger slog.Logger, listenAddr string) {
srv := http.Server{
Addr: listenAddr,
Handler: o.Router(t, logger),
ReadTimeout: 10 * time.Second,
}
serveDone := make(chan struct{})
go func() {
defer close(serveDone)
err := srv.ListenAndServe()
if err != nil && !xerrors.Is(err, http.ErrServerClosed) {
t.Error("HTTP server error:", err)
}
}()
t.Cleanup(func() {
_ = srv.Close()
<-serveDone
})
}
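// NGINXServerOptions starts the same server as SimpleServerOptions on the next
// port up and fronts it with an nginx reverse proxy listening on listenAddr,
// so that DERP and coordinator websocket upgrades are exercised through a
// proxy.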
type NGINXServerOptions struct {
SimpleServerOptions
}
var _ ServerStarter = NGINXServerOptions{}
func (o NGINXServerOptions) StartServer(t *testing.T, logger slog.Logger, listenAddr string) {
host, nginxPortStr, err := net.SplitHostPort(listenAddr)
require.NoError(t, err)
nginxPort, err := strconv.Atoi(nginxPortStr)
require.NoError(t, err)
serverPort := nginxPort + 1
serverListenAddr := net.JoinHostPort(host, strconv.Itoa(serverPort))
o.SimpleServerOptions.StartServer(t, logger, serverListenAddr)
startNginx(t, nginxPortStr, serverListenAddr)
}
func startNginx(t *testing.T, listenPort, serverAddr string) {
cfg := `events {}
http {
server {
listen ` + listenPort + `;
server_name _;
location / {
proxy_pass http://` + serverAddr + `;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Host $server_name;
}
}
}
`
dir := t.TempDir()
cfgPath := filepath.Join(dir, "nginx.conf")
err := os.WriteFile(cfgPath, []byte(cfg), 0o600)
require.NoError(t, err)
// ExecBackground will handle cleanup.
_, _ = ExecBackground(t, "server.nginx", nil, "nginx", []string{"-c", cfgPath})
}
type BasicClientStarter struct {
BlockEndpoints bool
DERPForceWebsockets bool
// WaitForConnection means wait for (any) peer connection before returning from StartClient
WaitForConnection bool
// WaitForDirect means wait for a direct peer connection before returning from StartClient
WaitForDirect bool
// Service is a network service (e.g. an echo server) to start on the client. If Wait* is set, the service is
// started prior to waiting.
Service NetworkService
LogPackets bool
}
type NetworkService interface {
StartService(t *testing.T, logger slog.Logger, conn *tailnet.Conn)
}
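// A client that serves echo traffic, logs packets, and waits for a direct
// connection might be started with, for example (illustrative values only):
//
//	BasicClientStarter{
//		WaitForDirect: true,
//		Service:       UDPEchoService{},
//		LogPackets:    true,
//	}
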
func (b BasicClientStarter) StartClient(t *testing.T, logger slog.Logger, serverURL *url.URL, derpMap *tailcfg.DERPMap, me, peer Client) *tailnet.Conn {
var hook capture.Callback
if b.LogPackets {
pktLogger := packetLogger{logger}
hook = pktLogger.LogPacket
}
conn := startClientOptions(t, logger, serverURL, me, peer, &tailnet.Options{
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.PrefixFromUUID(me.ID)},
DERPMap: derpMap,
BlockEndpoints: b.BlockEndpoints,
Logger: logger,
DERPForceWebSockets: b.DERPForceWebsockets,
ListenPort: me.ListenPort,
// These tests don't have an internet connection, so we need to force the
// network up for magicsock to do anything.
ForceNetworkUp: true,
CaptureHook: hook,
})
if b.Service != nil {
b.Service.StartService(t, logger, conn)
}
if b.WaitForConnection || b.WaitForDirect {
// Wait for connection to be established.
peerIP := tailnet.TailscaleServicePrefix.AddrFromUUID(peer.ID)
require.Eventually(t, func() bool {
t.Log("attempting ping to peer to judge direct connection")
ctx := testutil.Context(t, testutil.WaitShort)
_, p2p, pong, err := conn.Ping(ctx, peerIP)
if err != nil {
t.Logf("ping failed: %v", err)
return false
}
if !p2p && b.WaitForDirect {
t.Log("ping succeeded, but not direct yet")
return false
}
t.Logf("ping succeeded, p2p=%t, endpoint=%s", p2p, pong.Endpoint)
return true
}, testutil.WaitLong, testutil.IntervalMedium)
}
return conn
}
const EchoPort = 2381
type UDPEchoService struct{}
func (UDPEchoService) StartService(t *testing.T, logger slog.Logger, _ *tailnet.Conn) {
// tailnet doesn't handle UDP connections "in-process" the way we do for TCP, so we need to listen on an OS
// socket and let tailnet forward packets to it.
l, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv6zero, // all interfaces
Port: EchoPort,
})
require.NoError(t, err)
// set path MTU discovery so that we don't fragment the responses.
c, err := l.SyscallConn()
require.NoError(t, err)
var sockErr error
err = c.Control(func(fd uintptr) {
sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
})
require.NoError(t, err)
require.NoError(t, sockErr)
logger.Info(context.Background(), "started UDPEcho server")
t.Cleanup(func() {
lCloseErr := l.Close()
if lCloseErr != nil {
t.Logf("error closing UDPEcho listener: %v", lCloseErr)
}
})
go func() {
buf := make([]byte, 1500)
for {
n, remote, readErr := l.ReadFromUDP(buf)
if readErr != nil {
logger.Info(context.Background(), "error reading UDPEcho listener", slog.Error(readErr))
return
}
logger.Info(context.Background(), "received UDPEcho packet",
slog.F("len", n), slog.F("remote", remote))
n, writeErr := l.WriteToUDP(buf[:n], remote)
if writeErr != nil {
logger.Info(context.Background(), "error writing UDPEcho listener", slog.Error(writeErr))
return
}
logger.Info(context.Background(), "wrote UDPEcho packet",
slog.F("len", n), slog.F("remote", remote))
}
}()
}
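// The peer then reaches the echo service over the tailnet, e.g. (a sketch,
// assuming tailnet.Conn exposes a UDP dialer such as DialContextUDP):
//
//	addr := netip.AddrPortFrom(tailnet.TailscaleServicePrefix.AddrFromUUID(peer.ID), EchoPort)
//	echoConn, err := conn.DialContextUDP(ctx, addr)
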
func startClientOptions(t *testing.T, logger slog.Logger, serverURL *url.URL, me, peer Client, options *tailnet.Options) *tailnet.Conn {
u, err := serverURL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", me.ID.String()))
require.NoError(t, err)
//nolint:bodyclose
ws, _, err := websocket.Dial(context.Background(), u.String(), nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = ws.Close(websocket.StatusNormalClosure, "closing websocket")
})
client, err := tailnet.NewDRPCClient(
websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
logger,
)
require.NoError(t, err)
coord, err := client.Coordinate(context.Background())
require.NoError(t, err)
conn, err := tailnet.NewConn(options)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
var coordination tailnet.CloserWaiter
if me.TunnelSrc {
ctrl := tailnet.NewTunnelSrcCoordController(logger, conn)
ctrl.AddDestination(peer.ID)
coordination = ctrl.New(coord)
} else {
// use the "Agent" controller so that we act as a tunnel destination and send "ReadyForHandshake" acks.
ctrl := tailnet.NewAgentCoordinationController(logger, conn)
coordination = ctrl.New(coord)
}
t.Cleanup(func() {
cctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
_ = coordination.Close(cctx)
})
return conn
}
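// basicDERPMap builds a single-region DERP map pointing at the test server's
// /derp endpoint over plain HTTP, with STUN disabled (STUNPort: -1) since the
// isolated test network has no STUN servers.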
func basicDERPMap(serverURLStr string) (*tailcfg.DERPMap, error) {
serverURL, err := url.Parse(serverURLStr)
if err != nil {
return nil, xerrors.Errorf("parse server URL %q: %w", serverURLStr, err)
}
portStr := serverURL.Port()
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
}
hostname := serverURL.Hostname()
ipv4 := "none"
ipv6 := "none"
ip, err := netip.ParseAddr(hostname)
if err == nil {
hostname = ""
if ip.Is4() {
ipv4 = ip.String()
}
if ip.Is6() {
ipv6 = ip.String()
}
}
return &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "test",
RegionName: "test server",
Nodes: []*tailcfg.DERPNode{
{
Name: "test0",
RegionID: 1,
HostName: hostname,
IPv4: ipv4,
IPv6: ipv6,
DERPPort: port,
STUNPort: -1,
ForceHTTP: true,
InsecureForTests: true,
},
},
},
},
}, nil
}
// ExecBackground starts a subprocess with the given flags and returns a
// channel that will receive the error when the subprocess exits. The returned
// function can be used to close the subprocess.
//
// processName is used to identify the subprocess in logs.
//
// Optionally, a network namespace can be passed to run the subprocess in.
//
// Do not call the returned close function and then also wait on the channel;
// in that case, use the error returned by the close function instead.
//
// Cleanup is handled automatically if you don't care about monitoring the
// process manually.
func ExecBackground(t *testing.T, processName string, netNS *os.File, name string, args []string) (<-chan error, func() error) {
if netNS != nil {
// We use nsenter to enter the namespace.
// We can't use `setns` easily from Golang in the parent process because
// you can't execute the syscall in the forked child thread before it
// execs.
// We can't use `setns` easily from Golang in the child process because
// by the time you call it, the process has already created multiple
// threads.
args = append([]string{"--net=/proc/self/fd/3", name}, args...)
name = "nsenter"
}
cmd := exec.Command(name, args...)
if netNS != nil {
cmd.ExtraFiles = []*os.File{netNS}
}
out := &testWriter{
name: processName,
t: t,
}
t.Cleanup(out.Flush)
cmd.Stdout = out
cmd.Stderr = out
cmd.SysProcAttr = &syscall.SysProcAttr{
Pdeathsig: syscall.SIGTERM,
}
err := cmd.Start()
require.NoError(t, err)
waitErr := make(chan error, 1)
go func() {
err := cmd.Wait()
if err != nil && strings.Contains(err.Error(), "signal: terminated") {
err = nil
}
waitErr <- err
close(waitErr)
}()
closeFn := func() error {
_ = cmd.Process.Signal(syscall.SIGTERM)
select {
case <-time.After(5 * time.Second):
_ = cmd.Process.Kill()
case err := <-waitErr:
return err
}
return <-waitErr
}
t.Cleanup(func() {
select {
case err := <-waitErr:
if err != nil {
t.Log("subprocess exited: " + err.Error())
}
return
default:
}
_ = closeFn()
})
return waitErr, closeFn
}
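// testWriter forwards subprocess output to the test log line by line. Test
// result lines ("--- PASS"/"--- FAIL" or bare PASS/FAIL) are captured and
// replayed on Flush, and any FAIL line also fails the parent test.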
type testWriter struct {
mut sync.Mutex
name string
t *testing.T
capturedLines []string
}
func (w *testWriter) Write(p []byte) (n int, err error) {
w.mut.Lock()
defer w.mut.Unlock()
str := string(p)
split := strings.Split(str, "\n")
for _, s := range split {
if s == "" {
continue
}
// If a line begins with "\s*--- (PASS|FAIL)" or is just PASS or FAIL,
// then it's a test result line. We want to capture it and log it later.
trimmed := strings.TrimSpace(s)
if strings.HasPrefix(trimmed, "--- PASS") || strings.HasPrefix(trimmed, "--- FAIL") || trimmed == "PASS" || trimmed == "FAIL" {
// Also fail the test if we see a FAIL line.
if strings.Contains(trimmed, "FAIL") {
w.t.Errorf("subprocess logged test failure: %s: \t%s", w.name, s)
}
w.capturedLines = append(w.capturedLines, s)
continue
}
w.t.Logf("%s output: \t%s", w.name, s)
}
return len(p), nil
}
func (w *testWriter) Flush() {
w.mut.Lock()
defer w.mut.Unlock()
for _, s := range w.capturedLines {
w.t.Logf("%s output: \t%s", w.name, s)
}
w.capturedLines = nil
}
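// packetLogger implements the wgengine capture hook and logs the path,
// timestamp, decoded headers, and length of every packet crossing the tailnet
// connection when LogPackets is enabled.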
type packetLogger struct {
l slog.Logger
}
func (p packetLogger) LogPacket(path capture.Path, when time.Time, pkt []byte, _ packet.CaptureMeta) {
q := new(packet.Parsed)
q.Decode(pkt)
p.l.Info(context.Background(), "Packet",
slog.F("path", pathString(path)),
slog.F("when", when),
slog.F("decode", q.String()),
slog.F("len", len(pkt)),
)
}
func pathString(path capture.Path) string {
switch path {
case capture.FromLocal:
return "Local"
case capture.FromPeer:
return "Peer"
case capture.SynthesizedToLocal:
return "SynthesizedToLocal"
case capture.SynthesizedToPeer:
return "SynthesizedToPeer"
case capture.PathDisco:
return "Disco"
default:
return "<<UNKNOWN>>"
}
}