mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
chore: refactor TestServer_X11 to use inproc networking (#18564)
relates to #18263 Refactors the x11Forwarder to accept a networking `interface` that we can fake out for testing. This isolates the unit tests from other processes listening in the port range used by X11 forwarding. This will become extremely important in up-stack PRs where we listen on every port in the range and need to control which ports have conflicts.
This commit is contained in:
@ -117,6 +117,10 @@ type Config struct {
|
|||||||
// Note that this is different from the devcontainers feature, which uses
|
// Note that this is different from the devcontainers feature, which uses
|
||||||
// subagents.
|
// subagents.
|
||||||
ExperimentalContainers bool
|
ExperimentalContainers bool
|
||||||
|
// X11Net allows overriding the networking implementation used for X11
|
||||||
|
// forwarding listeners. When nil, a default implementation backed by the
|
||||||
|
// standard library networking package is used.
|
||||||
|
X11Net X11Network
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@ -196,6 +200,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
|||||||
displayOffset: *config.X11DisplayOffset,
|
displayOffset: *config.X11DisplayOffset,
|
||||||
sessions: make(map[*x11Session]struct{}),
|
sessions: make(map[*x11Session]struct{}),
|
||||||
connections: make(map[net.Conn]struct{}),
|
connections: make(map[net.Conn]struct{}),
|
||||||
|
network: func() X11Network {
|
||||||
|
if config.X11Net != nil {
|
||||||
|
return config.X11Net
|
||||||
|
}
|
||||||
|
return osNet{}
|
||||||
|
}(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,12 +37,30 @@ const (
|
|||||||
X11MaxPort = X11StartPort + X11MaxDisplays
|
X11MaxPort = X11StartPort + X11MaxDisplays
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// X11Network abstracts the creation of network listeners for X11 forwarding.
|
||||||
|
// It is intended mainly for testing; production code uses the default
|
||||||
|
// implementation backed by the operating system networking stack.
|
||||||
|
type X11Network interface {
|
||||||
|
Listen(network, address string) (net.Listener, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// osNet is the default X11Network implementation that uses the standard
|
||||||
|
// library network stack.
|
||||||
|
type osNet struct{}
|
||||||
|
|
||||||
|
func (osNet) Listen(network, address string) (net.Listener, error) {
|
||||||
|
return net.Listen(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
type x11Forwarder struct {
|
type x11Forwarder struct {
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
x11HandlerErrors *prometheus.CounterVec
|
x11HandlerErrors *prometheus.CounterVec
|
||||||
fs afero.Fs
|
fs afero.Fs
|
||||||
displayOffset int
|
displayOffset int
|
||||||
|
|
||||||
|
// network creates X11 listener sockets. Defaults to osNet{}.
|
||||||
|
network X11Network
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
sessions map[*x11Session]struct{}
|
sessions map[*x11Session]struct{}
|
||||||
connections map[net.Conn]struct{}
|
connections map[net.Conn]struct{}
|
||||||
@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections(
|
|||||||
x.closeAndRemoveSession(session)
|
x.closeAndRemoveSession(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpConn, ok := conn.(*net.TCPConn)
|
var originAddr string
|
||||||
if !ok {
|
var originPort uint32
|
||||||
x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
|
|
||||||
_ = conn.Close()
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||||
continue
|
if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok {
|
||||||
|
originAddr = tcpAddr.IP.String()
|
||||||
|
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
|
||||||
|
originPort = uint32(tcpAddr.Port)
|
||||||
}
|
}
|
||||||
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
|
}
|
||||||
if !ok {
|
// Fallback values for in-memory or non-TCP connections.
|
||||||
x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
|
if originAddr == "" {
|
||||||
_ = conn.Close()
|
originAddr = "127.0.0.1"
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
|
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
|
||||||
OriginatorAddress string
|
OriginatorAddress string
|
||||||
OriginatorPort uint32
|
OriginatorPort uint32
|
||||||
}{
|
}{
|
||||||
OriginatorAddress: tcpAddr.IP.String(),
|
OriginatorAddress: originAddr,
|
||||||
// #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
|
OriginatorPort: originPort,
|
||||||
OriginatorPort: uint32(tcpAddr.Port),
|
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
|
x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
|
||||||
@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
|
|||||||
// createX11Listener creates a listener for X11 forwarding, it will use
|
// createX11Listener creates a listener for X11 forwarding, it will use
|
||||||
// the next available port starting from X11StartPort and displayOffset.
|
// the next available port starting from X11StartPort and displayOffset.
|
||||||
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
||||||
var lc net.ListenConfig
|
|
||||||
// Look for an open port to listen on.
|
// Look for an open port to listen on.
|
||||||
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return nil, -1, ctx.Err()
|
return nil, -1, ctx.Err()
|
||||||
}
|
}
|
||||||
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
|
|
||||||
|
ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
display = port - X11StartPort
|
display = port - X11StartPort
|
||||||
return ln, display, nil
|
return ln, display, nil
|
||||||
|
@ -3,7 +3,6 @@ package agentssh_test
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) {
|
|||||||
t.Skip("X11 forwarding is only supported on Linux")
|
t.Skip("X11 forwarding is only supported on Linux")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
|
|
||||||
|
// Use in-process networking for X11 forwarding.
|
||||||
|
inproc := testutil.NewInProcNet()
|
||||||
|
|
||||||
|
// Create server config with custom X11 listener.
|
||||||
|
cfg := &agentssh.Config{
|
||||||
|
X11Net: inproc,
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
err = s.UpdateHostSigner(42)
|
err = s.UpdateHostSigner(42)
|
||||||
@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) {
|
|||||||
|
|
||||||
x11Chans := c.HandleChannelOpen("x11")
|
x11Chans := c.HandleChannelOpen("x11")
|
||||||
payload := "hello world"
|
payload := "hello world"
|
||||||
require.Eventually(t, func() bool {
|
go func() {
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
|
conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)))
|
||||||
if err == nil {
|
assert.NoError(t, err)
|
||||||
_, err = conn.Write([]byte(payload))
|
_, err = conn.Write([]byte(payload))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}()
|
||||||
return err == nil
|
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
|
||||||
|
|
||||||
x11 := <-x11Chans
|
x11 := testutil.RequireReceive(ctx, t, x11Chans)
|
||||||
ch, reqs, err := x11.Accept()
|
ch, reqs, err := x11.Accept()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
go gossh.DiscardRequests(reqs)
|
go gossh.DiscardRequests(reqs)
|
||||||
|
Reference in New Issue
Block a user