fix: use fake local network for port-forward tests (#11119)

Fixes #10979

Testing code that listens on a specific port has created a long battle with flakes.  Previous attempts to deal with this include opening a listener on a port chosen by the OS, then closing the listener, noting the port and starting the test with that port.
This still flakes, notably in macOS which has a proclivity to reuse ports quickly.

Instead of fighting with the chaos that is an OS networking stack, this PR fakes the host networking in tests.

I've taken a small step here, only faking out the Listen() calls that port-forward makes, but I think over time we should be transitioning all networking the CLI does to an abstract interface so we can fake it.  This allows us to run in parallel without flakes and
presents an opportunity to test error paths as well.
This commit is contained in:
Spike Curtis
2023-12-11 14:51:56 +04:00
committed by GitHub
parent 37f6b38d53
commit 50575e1a9a
4 changed files with 191 additions and 106 deletions

View File

@ -189,6 +189,7 @@ type Invocation struct {
Stderr io.Writer Stderr io.Writer
Stdin io.Reader Stdin io.Reader
Logger slog.Logger Logger slog.Logger
Net Net
// testing // testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
@ -203,6 +204,7 @@ func (inv *Invocation) WithOS() *Invocation {
i.Stdin = os.Stdin i.Stdin = os.Stdin
i.Args = os.Args[1:] i.Args = os.Args[1:]
i.Environ = ParseEnviron(os.Environ(), "") i.Environ = ParseEnviron(os.Environ(), "")
i.Net = osNet{}
}) })
} }

50
cli/clibase/net.go Normal file
View File

@ -0,0 +1,50 @@
package clibase
import (
"net"
"strconv"
"github.com/pion/udp"
"golang.org/x/xerrors"
)
// Net abstracts CLI commands interacting with the operating system networking.
//
// At present, it covers opening local listening sockets, since doing this
// in testing is a challenge without flakes, since it's hard to pick a port we
// know a priori will be free.
type Net interface {
// Listen has the same semantics as `net.Listen` but also supports `udp`
Listen(network, address string) (net.Listener, error)
}
// osNet is an implementation that call the real OS for networking.
type osNet struct{}
func (osNet) Listen(network, address string) (net.Listener, error) {
switch network {
case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
return net.Listen(network, address)
case "udp":
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", address, err)
}
var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, address, err)
}
// Use pion here so that we get a stream-style net.Conn listener, instead
// of a packet-oriented connection that can read and write to multiple
// addresses.
return udp.Listen(network, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
default:
return nil, xerrors.Errorf("unknown listen network %q", network)
}
}

View File

@ -12,7 +12,6 @@ import (
"sync" "sync"
"syscall" "syscall"
"github.com/pion/udp"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
@ -121,6 +120,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
wg = new(sync.WaitGroup) wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs)) listeners = make([]net.Listener, len(specs))
closeAllListeners = func() { closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners { for _, l := range listeners {
if l == nil { if l == nil {
continue continue
@ -134,6 +134,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
for i, spec := range specs { for i, spec := range specs {
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
if err != nil { if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
return err return err
} }
listeners[i] = l listeners[i] = l
@ -151,8 +152,10 @@ func (r *RootCmd) portForward() *clibase.Cmd {
select { select {
case <-ctx.Done(): case <-ctx.Done():
logger.Debug(ctx, "command context expired waiting for signal", slog.Error(ctx.Err()))
closeErr = ctx.Err() closeErr = ctx.Err()
case <-sigs: case sig := <-sigs:
logger.Debug(ctx, "received signal", slog.F("signal", sig))
_, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections") _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections")
} }
@ -161,6 +164,7 @@ func (r *RootCmd) portForward() *clibase.Cmd {
}() }()
conn.AwaitReachable(ctx) conn.AwaitReachable(ctx)
logger.Debug(ctx, "read to accept connections to forward")
_, _ = fmt.Fprintln(inv.Stderr, "Ready!") _, _ = fmt.Fprintln(inv.Stderr, "Ready!")
wg.Wait() wg.Wait()
return closeErr return closeErr
@ -198,33 +202,7 @@ func listenAndPortForward(
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress)) logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
var ( l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
l net.Listener
err error
)
switch spec.listenNetwork {
case "tcp":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
case "udp":
var host, port string
host, port, err = net.SplitHostPort(spec.listenAddress)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
}
var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
}
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
default:
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
}
if err != nil { if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err) return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/pion/udp" "github.com/pion/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/agent/agenttest"
@ -45,47 +46,35 @@ func TestPortForward_None(t *testing.T) {
pty.ExpectMatch("port-forward <workspace>") pty.ExpectMatch("port-forward <workspace>")
} }
//nolint:tparallel,paralleltest // Subtests require setup that must not be done in parallel.
func TestPortForward(t *testing.T) { func TestPortForward(t *testing.T) {
t.Parallel()
cases := []struct { cases := []struct {
name string name string
network string network string
// The flag to pass to `coder port-forward X` to port-forward this type // The flag(s) to pass to `coder port-forward X` to port-forward this type
// of connection. Has two format args (both strings), the first is the // of connection. Has one format arg (string) for the remote address.
// local address and the second is the remote address. flag []string
flag string
// setupRemote creates a "remote" listener to emulate a service in the // setupRemote creates a "remote" listener to emulate a service in the
// workspace. // workspace.
setupRemote func(t *testing.T) net.Listener setupRemote func(t *testing.T) net.Listener
// setupLocal returns an available port that the // the local address(es) to "dial"
// port-forward command will listen on "locally". Returns the address localAddress []string
// you pass to net.Dial, and the port/path you pass to `coder
// port-forward`.
setupLocal func(t *testing.T) (string, string)
}{ }{
{ {
name: "TCP", name: "TCP",
network: "tcp", network: "tcp",
flag: "--tcp=%v:%v", flag: []string{"--tcp=5555:%v", "--tcp=6666:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
}, },
setupLocal: func(t *testing.T) (string, string) { localAddress: []string{"127.0.0.1:5555", "127.0.0.1:6666"},
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener to generate random port")
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoErrorf(t, err, "split TCP address %q", l.Addr().String())
return l.Addr().String(), port
},
}, },
{ {
name: "UDP", name: "UDP",
network: "udp", network: "udp",
flag: "--udp=%v:%v", flag: []string{"--udp=7777:%v", "--udp=8888:%v"},
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T) net.Listener {
addr := net.UDPAddr{ addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
@ -95,38 +84,17 @@ func TestPortForward(t *testing.T) {
require.NoError(t, err, "create UDP listener") require.NoError(t, err, "create UDP listener")
return l return l
}, },
setupLocal: func(t *testing.T) (string, string) { localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
addr := net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
}
l, err := udp.Listen("udp", &addr)
require.NoError(t, err, "create UDP listener to generate random port")
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoErrorf(t, err, "split UDP address %q", l.Addr().String())
return l.Addr().String(), port
},
}, },
{ {
name: "TCPWithAddress", name: "TCPWithAddress",
network: "tcp", network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
flag: "--tcp=%v:%v",
setupRemote: func(t *testing.T) net.Listener { setupRemote: func(t *testing.T) net.Listener {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener") require.NoError(t, err, "create TCP listener")
return l return l
}, },
setupLocal: func(t *testing.T) (string, string) { localAddress: []string{"10.10.10.99:9999", "10.10.10.10:1010"},
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err, "create TCP listener to generate random port")
defer l.Close()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoErrorf(t, err, "split TCP address %q", l.Addr().String())
return l.Addr().String(), fmt.Sprint("0.0.0.0:", port)
},
}, },
} }
@ -141,16 +109,12 @@ func TestPortForward(t *testing.T) {
for _, c := range cases { for _, c := range cases {
c := c c := c
// No parallel tests here because setupLocal reserves
// a free open port which is not guaranteed to be free
// between the listener closing and port-forward ready.
//nolint:tparallel,paralleltest
t.Run(c.name+"_OnePort", func(t *testing.T) { t.Run(c.name+"_OnePort", func(t *testing.T) {
t.Parallel()
p1 := setupTestListener(t, c.setupRemote(t)) p1 := setupTestListener(t, c.setupRemote(t))
// Create a flag that forwards from local to listener 1. // Create a flag that forwards from local to listener 1.
localAddress, localFlag := c.setupLocal(t) flag := fmt.Sprintf(c.flag[0], p1)
flag := fmt.Sprintf(c.flag, localFlag, p1)
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
// the "local" listener. // the "local" listener.
@ -160,21 +124,27 @@ func TestPortForward(t *testing.T) {
inv.Stdin = pty.Input() inv.Stdin = pty.Input()
inv.Stdout = pty.Output() inv.Stdout = pty.Output()
inv.Stderr = pty.Output() inv.Stderr = pty.Output()
iNet := newInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
go func() { go func() {
errC <- inv.WithContext(ctx).Run() err := inv.WithContext(ctx).Run()
t.Logf("command complete; err=%s", err.Error())
errC <- err
}() }()
pty.ExpectMatchContext(ctx, "Ready!") pty.ExpectMatchContext(ctx, "Ready!")
// Open two connections simultaneously and test them out of // Open two connections simultaneously and test them out of
// sync. // sync.
d := net.Dialer{Timeout: testutil.WaitShort} dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
c1, err := d.DialContext(ctx, c.network, localAddress) defer dialCtxCancel()
c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
require.NoError(t, err, "open connection 1 to 'local' listener") require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close() defer c1.Close()
c2, err := d.DialContext(ctx, c.network, localAddress) c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
require.NoError(t, err, "open connection 2 to 'local' listener") require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close() defer c2.Close()
testDial(t, c2) testDial(t, c2)
@ -185,21 +155,16 @@ func TestPortForward(t *testing.T) {
require.ErrorIs(t, err, context.Canceled) require.ErrorIs(t, err, context.Canceled)
}) })
// No parallel tests here because setupLocal reserves
// a free open port which is not guaranteed to be free
// between the listener closing and port-forward ready.
//nolint:tparallel,paralleltest
t.Run(c.name+"_TwoPorts", func(t *testing.T) { t.Run(c.name+"_TwoPorts", func(t *testing.T) {
t.Parallel()
var ( var (
p1 = setupTestListener(t, c.setupRemote(t)) p1 = setupTestListener(t, c.setupRemote(t))
p2 = setupTestListener(t, c.setupRemote(t)) p2 = setupTestListener(t, c.setupRemote(t))
) )
// Create a flags for listener 1 and listener 2. // Create a flags for listener 1 and listener 2.
localAddress1, localFlag1 := c.setupLocal(t) flag1 := fmt.Sprintf(c.flag[0], p1)
localAddress2, localFlag2 := c.setupLocal(t) flag2 := fmt.Sprintf(c.flag[1], p2)
flag1 := fmt.Sprintf(c.flag, localFlag1, p1)
flag2 := fmt.Sprintf(c.flag, localFlag2, p2)
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
// the "local" listeners. // the "local" listeners.
@ -209,6 +174,9 @@ func TestPortForward(t *testing.T) {
inv.Stdin = pty.Input() inv.Stdin = pty.Input()
inv.Stdout = pty.Output() inv.Stdout = pty.Output()
inv.Stderr = pty.Output() inv.Stderr = pty.Output()
iNet := newInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
@ -219,11 +187,12 @@ func TestPortForward(t *testing.T) {
// Open a connection to both listener 1 and 2 simultaneously and // Open a connection to both listener 1 and 2 simultaneously and
// then test them out of order. // then test them out of order.
d := net.Dialer{Timeout: testutil.WaitShort} dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
c1, err := d.DialContext(ctx, c.network, localAddress1) defer dialCtxCancel()
c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
require.NoError(t, err, "open connection 1 to 'local' listener 1") require.NoError(t, err, "open connection 1 to 'local' listener 1")
defer c1.Close() defer c1.Close()
c2, err := d.DialContext(ctx, c.network, localAddress2) c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[1]})
require.NoError(t, err, "open connection 2 to 'local' listener 2") require.NoError(t, err, "open connection 2 to 'local' listener 2")
defer c2.Close() defer c2.Close()
testDial(t, c2) testDial(t, c2)
@ -235,12 +204,8 @@ func TestPortForward(t *testing.T) {
}) })
} }
// Test doing TCP and UDP at the same time.
// No parallel tests here because setupLocal reserves
// a free open port which is not guaranteed to be free
// between the listener closing and port-forward ready.
//nolint:tparallel,paralleltest
t.Run("All", func(t *testing.T) { t.Run("All", func(t *testing.T) {
t.Parallel()
var ( var (
dials = []addr{} dials = []addr{}
flags = []string{} flags = []string{}
@ -250,12 +215,11 @@ func TestPortForward(t *testing.T) {
for _, c := range cases { for _, c := range cases {
p := setupTestListener(t, c.setupRemote(t)) p := setupTestListener(t, c.setupRemote(t))
localAddress, localFlag := c.setupLocal(t)
dials = append(dials, addr{ dials = append(dials, addr{
network: c.network, network: c.network,
addr: localAddress, addr: c.localAddress[0],
}) })
flags = append(flags, fmt.Sprintf(c.flag, localFlag, p)) flags = append(flags, fmt.Sprintf(c.flag[0], p))
} }
// Launch port-forward in a goroutine so we can start dialing // Launch port-forward in a goroutine so we can start dialing
@ -264,6 +228,9 @@ func TestPortForward(t *testing.T) {
clitest.SetupConfig(t, member, root) clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv) pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output() inv.Stderr = pty.Output()
iNet := newInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
errC := make(chan error) errC := make(chan error)
@ -274,11 +241,12 @@ func TestPortForward(t *testing.T) {
// Open connections to all items in the "dial" array. // Open connections to all items in the "dial" array.
var ( var (
d = net.Dialer{Timeout: testutil.WaitShort} dialCtx, dialCtxCancel = context.WithTimeout(ctx, testutil.WaitShort)
conns = make([]net.Conn, len(dials)) conns = make([]net.Conn, len(dials))
) )
defer dialCtxCancel()
for i, a := range dials { for i, a := range dials {
c, err := d.DialContext(ctx, a.network, a.addr) c, err := iNet.dial(dialCtx, a)
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1) require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
t.Cleanup(func() { t.Cleanup(func() {
_ = c.Close() _ = c.Close()
@ -396,3 +364,90 @@ type addr struct {
network string network string
addr string addr string
} }
func (a addr) Network() string {
return a.network
}
func (a addr) Address() string {
return a.addr
}
func (a addr) String() string {
return a.network + "|" + a.addr
}
type inProcNet struct {
sync.Mutex
listeners map[addr]*inProcListener
}
type inProcListener struct {
c chan net.Conn
n *inProcNet
a addr
o sync.Once
}
func newInProcNet() *inProcNet {
return &inProcNet{listeners: make(map[addr]*inProcListener)}
}
func (n *inProcNet) Listen(network, address string) (net.Listener, error) {
a := addr{network, address}
n.Lock()
defer n.Unlock()
if _, ok := n.listeners[a]; ok {
return nil, xerrors.New("busy")
}
l := newInProcListener(n, a)
n.listeners[a] = l
return l, nil
}
func (n *inProcNet) dial(ctx context.Context, a addr) (net.Conn, error) {
n.Lock()
defer n.Unlock()
l, ok := n.listeners[a]
if !ok {
return nil, xerrors.Errorf("nothing listening on %s", a)
}
x, y := net.Pipe()
select {
case <-ctx.Done():
return nil, ctx.Err()
case l.c <- x:
return y, nil
}
}
func newInProcListener(n *inProcNet, a addr) *inProcListener {
return &inProcListener{
c: make(chan net.Conn),
n: n,
a: a,
}
}
func (l *inProcListener) Accept() (net.Conn, error) {
c, ok := <-l.c
if !ok {
return nil, net.ErrClosed
}
return c, nil
}
func (l *inProcListener) Close() error {
l.o.Do(func() {
l.n.Lock()
defer l.n.Unlock()
delete(l.n.listeners, l.a)
close(l.c)
})
return nil
}
func (l *inProcListener) Addr() net.Addr {
return l.a
}