From a5bfb200fc5c6eb05a8a5e5a27ef7893ea1ba439 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 27 Jun 2025 14:56:33 +0400 Subject: [PATCH] chore: refactor TestServer_X11 to use inproc networking (#18564) relates to #18263 Refactors the x11Forwarder to accept a networking `interface` that we can fake out for testing. This isolates the unit tests from other processes listening in the port range used by X11 forwarding. This will become extremely important in up-stack PRs where we listen on every port in the range and need to control which ports have conflicts. --- agent/agentssh/agentssh.go | 10 ++++++++ agent/agentssh/x11.go | 49 ++++++++++++++++++++++++++------------ agent/agentssh/x11_test.go | 32 +++++++++++++++---------- 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index ec682a735c..6e3760c643 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -117,6 +117,10 @@ type Config struct { // Note that this is different from the devcontainers feature, which uses // subagents. ExperimentalContainers bool + // X11Net allows overriding the networking implementation used for X11 + // forwarding listeners. When nil, a default implementation backed by the + // standard library networking package is used. + X11Net X11Network } type Server struct { @@ -196,6 +200,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom displayOffset: *config.X11DisplayOffset, sessions: make(map[*x11Session]struct{}), connections: make(map[net.Conn]struct{}), + network: func() X11Network { + if config.X11Net != nil { + return config.X11Net + } + return osNet{} + }(), }, } diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 8c23d32bfa..05d9f866c1 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -37,12 +37,30 @@ const ( X11MaxPort = X11StartPort + X11MaxDisplays ) +// X11Network abstracts the creation of network listeners for X11 forwarding. +// It is intended mainly for testing; production code uses the default +// implementation backed by the operating system networking stack. +type X11Network interface { + Listen(network, address string) (net.Listener, error) +} + +// osNet is the default X11Network implementation that uses the standard +// library network stack. +type osNet struct{} + +func (osNet) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + type x11Forwarder struct { logger slog.Logger x11HandlerErrors *prometheus.CounterVec fs afero.Fs displayOffset int + // network creates X11 listener sockets. Defaults to osNet{}. + network X11Network + mu sync.Mutex sessions map[*x11Session]struct{} connections map[net.Conn]struct{} @@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections( x.closeAndRemoveSession(session) } - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) - _ = conn.Close() - continue + var originAddr string + var originPort uint32 + + if tcpConn, ok := conn.(*net.TCPConn); ok { + if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok { + originAddr = tcpAddr.IP.String() + // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) + originPort = uint32(tcpAddr.Port) + } } - tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) - if !ok { - x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) - _ = conn.Close() - continue + // Fallback values for in-memory or non-TCP connections. + if originAddr == "" { + originAddr = "127.0.0.1" } channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { OriginatorAddress string OriginatorPort uint32 }{ - OriginatorAddress: tcpAddr.IP.String(), - // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) - OriginatorPort: uint32(tcpAddr.Port), + OriginatorAddress: originAddr, + OriginatorPort: originPort, })) if err != nil { x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) @@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() { // createX11Listener creates a listener for X11 forwarding, it will use // the next available port starting from X11StartPort and displayOffset. func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) { - var lc net.ListenConfig // Look for an open port to listen on. for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ { if ctx.Err() != nil { return nil, -1, ctx.Err() } - ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + + ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port)) if err == nil { display = port - X11StartPort return ln, display, nil diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index 39440da712..a680c088de 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -3,7 +3,6 @@ package agentssh_test import ( "bufio" "bytes" - "context" "encoding/hex" "fmt" "net" @@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) { t.Skip("X11 forwarding is only supported on Linux") } - ctx := context.Background() + ctx := testutil.Context(t, testutil.WaitShort) logger := testutil.Logger(t) fs := afero.NewMemMapFs() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{}) + + // Use in-process networking for X11 forwarding. + inproc := testutil.NewInProcNet() + + // Create server config with custom X11 listener. + cfg := &agentssh.Config{ + X11Net: inproc, + } + + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg) require.NoError(t, err) defer s.Close() err = s.UpdateHostSigner(42) @@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) { x11Chans := c.HandleChannelOpen("x11") payload := "hello world" - require.Eventually(t, func() bool { - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)) - if err == nil { - _, err = conn.Write([]byte(payload)) - assert.NoError(t, err) - _ = conn.Close() - } - return err == nil - }, testutil.WaitShort, testutil.IntervalFast) + go func() { + conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))) + assert.NoError(t, err) + _, err = conn.Write([]byte(payload)) + assert.NoError(t, err) + _ = conn.Close() + }() - x11 := <-x11Chans + x11 := testutil.RequireReceive(ctx, t, x11Chans) ch, reqs, err := x11.Accept() require.NoError(t, err) go gossh.DiscardRequests(reqs)