mirror of
https://github.com/coder/coder.git
synced 2025-07-23 21:32:07 +00:00
fix: detect and retry reverse port forward on used port (#10844)
Fixes #10799. The flake happens when we try to remote forward, but the port we've chosen is not free. In the flaked example, it's actually the SSH listener that occupies the port we try to remote forward, leading to confusing reads (cf. the linked issue). This fix simplifies the tests considerably by using the Go ssh client, rather than shelling out to OpenSSH. This avoids using a pseudoterminal, avoids the need for starting any local OS listeners to communicate the forwarding (Go ssh just returns in-process listeners), and avoids an OS listener to wire OpenSSH up to the agentConn. With the simplified logic, we can immediately tell if a remote forward on a random port fails, so we can do this in a loop until success or timeout. I've also simplified and fixed up the other forwarding tests. Since we set up forwarding in-process with Go ssh, we can remove a lot of the `require.Eventually` logic.
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -17,7 +18,6 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -25,7 +25,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
scp "github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
@ -52,7 +52,6 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
@ -648,150 +647,57 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:paralleltest // This test reserves a port.
|
||||
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
||||
random, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
_ = random.Close()
|
||||
tcpAddr, valid := random.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
randomPort := tcpAddr.Port
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
local, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer local.Close()
|
||||
tcpAddr, valid = local.Addr().(*net.TCPAddr)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn, err := local.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
go echoOnce(t, rl)
|
||||
|
||||
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
|
||||
sshClient := setupAgentSSHClient(ctx, t)
|
||||
|
||||
go func() {
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte("test"))
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
if !assert.Equal(t, "test", string(b)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
<-done
|
||||
|
||||
_ = proc.Kill()
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
//nolint:paralleltest // This test reserves a port.
|
||||
func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
random, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
_ = random.Close()
|
||||
tcpAddr, valid := random.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
randomPort := tcpAddr.Port
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
sshClient := setupAgentSSHClient(ctx, t)
|
||||
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
tcpAddr, valid = l.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
localPort := tcpAddr.Port
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
conn, err := l.Accept()
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var randomPort uint16
|
||||
var ll net.Listener
|
||||
var err error
|
||||
for {
|
||||
randomPort = pickRandomPort()
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
return
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
break
|
||||
}
|
||||
defer ll.Close()
|
||||
go echoOnce(t, ll)
|
||||
|
||||
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte("test"))
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
if !assert.Equal(t, "test", string(b)) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
<-done
|
||||
|
||||
_ = proc.Kill()
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
@ -799,52 +705,18 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
localSocketPath := filepath.Join(tmpdir, "local-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go echoOnce(t, l)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
sshClient := setupAgentSSHClient(ctx, t)
|
||||
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := os.Stat(localSocketPath)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
conn, err := net.Dial("unix", localSocketPath)
|
||||
conn, err := sshClient.Dial("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte("test"))
|
||||
@ -854,9 +726,6 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", string(b))
|
||||
_ = conn.Close()
|
||||
<-done
|
||||
|
||||
_ = proc.Kill()
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
||||
@ -867,66 +736,19 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
localSocketPath := filepath.Join(tmpdir, "local-socket")
|
||||
|
||||
l, err := net.Listen("unix", localSocketPath)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
sshClient := setupAgentSSHClient(ctx, t)
|
||||
|
||||
l, err := sshClient.ListenUnix(remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
go echoOnce(t, l)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
|
||||
|
||||
go func() {
|
||||
err := proc.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// It's possible that the socket is created but the server is not ready to
|
||||
// accept connections yet. We need to retry until we can connect.
|
||||
//
|
||||
// Note that we wait long here because if the tailnet connection has trouble
|
||||
// connecting, it could take 5 seconds or more to reconnect.
|
||||
var conn net.Conn
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
conn, err = net.Dial("unix", remoteSocketPath)
|
||||
return err == nil
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
conn, err := net.Dial("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", string(b))
|
||||
_ = conn.Close()
|
||||
|
||||
<-done
|
||||
|
||||
_ = proc.Kill()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_SFTP(t *testing.T) {
|
||||
@ -2063,50 +1885,14 @@ func TestAgent_DebugServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
|
||||
//nolint:dogsled
|
||||
// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
|
||||
func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
|
||||
//nolint: dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
waitGroup := sync.WaitGroup{}
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
ssh, err := agentConn.SSH(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
agentssh.Bicopy(context.Background(), conn, ssh)
|
||||
waitGroup.Done()
|
||||
}()
|
||||
}
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = listener.Close()
|
||||
waitGroup.Wait()
|
||||
})
|
||||
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
args := append(beforeArgs,
|
||||
"-o", "HostName "+tcpAddr.IP.String(),
|
||||
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"host",
|
||||
)
|
||||
args = append(args, afterArgs...)
|
||||
cmd := pty.Command("ssh", args...)
|
||||
return ptytest.Start(t, cmd)
|
||||
t.Cleanup(func() { sshClient.Close() })
|
||||
return sshClient
|
||||
}
|
||||
|
||||
func setupSSHSession(
|
||||
@ -2580,3 +2366,47 @@ func (s *syncWriter) Write(p []byte) (int, error) {
|
||||
defer s.mu.Unlock()
|
||||
return s.w.Write(p)
|
||||
}
|
||||
|
||||
// pickRandomPort returns a random port number from the ephemeral range, inclusive of both
// bounds. The port is chosen entirely at random instead of opening a listener and closing it
// to find a port that is likely to be free, since sometimes the OS reallocates the port very
// quickly. The returned port is NOT guaranteed to be free; callers must be prepared to retry.
func pickRandomPort() uint16 {
	const (
		// Overlap of windows, linux in https://en.wikipedia.org/wiki/Ephemeral_port
		minPort = 49152
		maxPort = 60999
	)
	// rand.Intn(n) returns a value in [0, n), so the span must be one larger than the
	// difference for maxPort itself to be reachable.
	offset := rand.Intn(maxPort - minPort + 1) //nolint: gosec
	return uint16(minPort + offset)
}
|
||||
|
||||
// echoOnce accepts a single connection, reads 4 bytes and echos them back
|
||||
func echoOnce(t *testing.T, ll net.Listener) {
|
||||
t.Helper()
|
||||
conn, err := ll.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(b)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// requireEcho sends 4 bytes and requires the read response to match what was sent.
|
||||
func requireEcho(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
_, err := conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
b := make([]byte, 4)
|
||||
_, err = conn.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", string(b))
|
||||
}
|
||||
|
Reference in New Issue
Block a user