chore: move InProcNet to testutil (#18563)

Moves `InProcNet` to `testutil` so that it can be reused by X11 forwarding tests (see up stack PRs).
This commit is contained in:
Spike Curtis
2025-06-27 14:42:22 +04:00
committed by GitHub
parent 6bebfd0ec6
commit abcf3df71a
2 changed files with 117 additions and 108 deletions

View File

@ -13,7 +13,6 @@ import (
"github.com/pion/udp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
@ -161,7 +160,7 @@ func TestPortForward(t *testing.T) {
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
iNet := newInProcNet()
iNet := testutil.NewInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@ -177,10 +176,10 @@ func TestPortForward(t *testing.T) {
// sync.
dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
defer dialCtxCancel()
c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
c1, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close()
c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
require.NoError(t, err, "open connection 2 to 'local' listener")
defer c2.Close()
testDial(t, c2)
@ -218,7 +217,7 @@ func TestPortForward(t *testing.T) {
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
iNet := newInProcNet()
iNet := testutil.NewInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@ -232,10 +231,10 @@ func TestPortForward(t *testing.T) {
// then test them out of order.
dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
defer dialCtxCancel()
c1, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[0]})
c1, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[0]))
require.NoError(t, err, "open connection 1 to 'local' listener 1")
defer c1.Close()
c2, err := iNet.dial(dialCtx, addr{c.network, c.localAddress[1]})
c2, err := iNet.Dial(dialCtx, testutil.NewAddr(c.network, c.localAddress[1]))
require.NoError(t, err, "open connection 2 to 'local' listener 2")
defer c2.Close()
testDial(t, c2)
@ -257,7 +256,7 @@ func TestPortForward(t *testing.T) {
t.Run("All", func(t *testing.T) {
t.Parallel()
var (
dials = []addr{}
dials = []testutil.Addr{}
flags = []string{}
)
@ -265,10 +264,7 @@ func TestPortForward(t *testing.T) {
for _, c := range cases {
p := setupTestListener(t, c.setupRemote(t))
dials = append(dials, addr{
network: c.network,
addr: c.localAddress[0],
})
dials = append(dials, testutil.NewAddr(c.network, c.localAddress[0]))
flags = append(flags, fmt.Sprintf(c.flag[0], p))
}
@ -279,7 +275,7 @@ func TestPortForward(t *testing.T) {
pty := ptytest.New(t).Attach(inv)
inv.Stderr = pty.Output()
iNet := newInProcNet()
iNet := testutil.NewInProcNet()
inv.Net = iNet
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
@ -296,7 +292,7 @@ func TestPortForward(t *testing.T) {
)
defer dialCtxCancel()
for i, a := range dials {
c, err := iNet.dial(dialCtx, a)
c, err := iNet.Dial(dialCtx, a)
require.NoErrorf(t, err, "open connection %v to 'local' listener %v", i+1, i+1)
t.Cleanup(func() {
_ = c.Close()
@ -340,7 +336,7 @@ func TestPortForward(t *testing.T) {
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
iNet := newInProcNet()
iNet := testutil.NewInProcNet()
inv.Net = iNet
// listen on port 5555 on IPv6 so it's busy when we try to port forward
@ -361,7 +357,7 @@ func TestPortForward(t *testing.T) {
// Test IPv4 still works
dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
defer dialCtxCancel()
c1, err := iNet.dial(dialCtx, addr{"tcp", "127.0.0.1:5555"})
c1, err := iNet.Dial(dialCtx, testutil.NewAddr("tcp", "127.0.0.1:5555"))
require.NoError(t, err, "open connection 1 to 'local' listener")
defer c1.Close()
testDial(t, c1)
@ -473,95 +469,3 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match")
}
type addr struct {
network string
addr string
}
func (a addr) Network() string {
return a.network
}
func (a addr) Address() string {
return a.addr
}
func (a addr) String() string {
return a.network + "|" + a.addr
}
type inProcNet struct {
sync.Mutex
listeners map[addr]*inProcListener
}
type inProcListener struct {
c chan net.Conn
n *inProcNet
a addr
o sync.Once
}
func newInProcNet() *inProcNet {
return &inProcNet{listeners: make(map[addr]*inProcListener)}
}
func (n *inProcNet) Listen(network, address string) (net.Listener, error) {
a := addr{network, address}
n.Lock()
defer n.Unlock()
if _, ok := n.listeners[a]; ok {
return nil, xerrors.New("busy")
}
l := newInProcListener(n, a)
n.listeners[a] = l
return l, nil
}
func (n *inProcNet) dial(ctx context.Context, a addr) (net.Conn, error) {
n.Lock()
defer n.Unlock()
l, ok := n.listeners[a]
if !ok {
return nil, xerrors.Errorf("nothing listening on %s", a)
}
x, y := net.Pipe()
select {
case <-ctx.Done():
return nil, ctx.Err()
case l.c <- x:
return y, nil
}
}
func newInProcListener(n *inProcNet, a addr) *inProcListener {
return &inProcListener{
c: make(chan net.Conn),
n: n,
a: a,
}
}
func (l *inProcListener) Accept() (net.Conn, error) {
c, ok := <-l.c
if !ok {
return nil, net.ErrClosed
}
return c, nil
}
func (l *inProcListener) Close() error {
l.o.Do(func() {
l.n.Lock()
defer l.n.Unlock()
delete(l.n.listeners, l.a)
close(l.c)
})
return nil
}
func (l *inProcListener) Addr() net.Addr {
return l.a
}