fix: detect and retry reverse port forward on used port (#10844)

Fixes #10799

The flake happens when we try to remote forward, but the port we've chosen is not free.  In the flaked example, it's actually the SSH listener that occupies the port we try to remote forward, leading to confusing reads (c.f. the linked issue).

This fix simplifies the tests considerably by using the Go ssh client, rather than shelling out to OpenSSH.  This avoids using a pseudoterminal, avoids the need for starting any local OS listeners to communicate the forwarding (go SSH just returns in-process listeners), and avoids an OS listener to wire OpenSSH up to the agentConn.

With the simplified logic, we can immediately tell if a remote forward on a random port fails, so we can do this in a loop until success or timeout.

I've also simplified and fixed up the other forwarding tests. Since we set up forwarding in-process with Go ssh, we can remove a lot of the `require.Eventually` logic.
This commit is contained in:
Spike Curtis
2023-11-27 09:42:45 +04:00
committed by GitHub
parent d5ddcbdda0
commit 6c67add2d9
2 changed files with 102 additions and 272 deletions

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httptest"
@ -17,7 +18,6 @@ import (
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -25,7 +25,7 @@ import (
"testing"
"time"
scp "github.com/bramvdbogaerde/go-scp"
"github.com/bramvdbogaerde/go-scp"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/pion/udp"
@ -52,7 +52,6 @@ import (
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
@ -648,150 +647,57 @@ func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
}
}
//nolint:paralleltest // This test reserves a port.
func TestAgent_TCPLocalForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
randomPort := tcpAddr.Port
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
local, err := net.Listen("tcp", "127.0.0.1:0")
rl, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer local.Close()
tcpAddr, valid = local.Addr().(*net.TCPAddr)
defer rl.Close()
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
require.True(t, valid)
remotePort := tcpAddr.Port
done := make(chan struct{})
go func() {
defer close(done)
conn, err := local.Accept()
if !assert.NoError(t, err) {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
go echoOnce(t, rl)
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "5"})
sshClient := setupAgentSSHClient(ctx, t)
go func() {
err := proc.Wait()
select {
case <-done:
default:
assert.NoError(t, err)
}
}()
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}
return true
}, testutil.WaitLong, testutil.IntervalSlow)
<-done
_ = proc.Kill()
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
require.NoError(t, err)
defer conn.Close()
requireEcho(t, conn)
}
//nolint:paralleltest // This test reserves a port.
func TestAgent_TCPRemoteForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
randomPort := tcpAddr.Port
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
sshClient := setupAgentSSHClient(ctx, t)
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer l.Close()
tcpAddr, valid = l.Addr().(*net.TCPAddr)
require.True(t, valid)
localPort := tcpAddr.Port
done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
localhost := netip.MustParseAddr("127.0.0.1")
var randomPort uint16
var ll net.Listener
var err error
for {
randomPort = pickRandomPort()
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
ll, err = sshClient.ListenTCP(addr)
if err != nil {
return
t.Logf("error remote forwarding: %s", err.Error())
select {
case <-ctx.Done():
t.Fatal("timed out getting random listener")
default:
continue
}
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
break
}
defer ll.Close()
go echoOnce(t, ll)
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "5"})
go func() {
err := proc.Wait()
select {
case <-done:
default:
assert.NoError(t, err)
}
}()
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}
return true
}, testutil.WaitLong, testutil.IntervalSlow)
<-done
_ = proc.Kill()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
require.NoError(t, err)
defer conn.Close()
requireEcho(t, conn)
}
func TestAgent_UnixLocalForwarding(t *testing.T) {
@ -799,52 +705,18 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
ctx := testutil.Context(t, testutil.WaitLong)
tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")
l, err := net.Listen("unix", remoteSocketPath)
require.NoError(t, err)
defer l.Close()
go echoOnce(t, l)
done := make(chan struct{})
go func() {
defer close(done)
sshClient := setupAgentSSHClient(ctx, t)
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
_, proc := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "5"})
go func() {
err := proc.Wait()
select {
case <-done:
default:
assert.NoError(t, err)
}
}()
require.Eventually(t, func() bool {
_, err := os.Stat(localSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
conn, err := net.Dial("unix", localSocketPath)
conn, err := sshClient.Dial("unix", remoteSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
@ -854,9 +726,6 @@ func TestAgent_UnixLocalForwarding(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()
<-done
_ = proc.Kill()
}
func TestAgent_UnixRemoteForwarding(t *testing.T) {
@ -867,66 +736,19 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) {
tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")
l, err := net.Listen("unix", localSocketPath)
ctx := testutil.Context(t, testutil.WaitLong)
sshClient := setupAgentSSHClient(ctx, t)
l, err := sshClient.ListenUnix(remoteSocketPath)
require.NoError(t, err)
defer l.Close()
go echoOnce(t, l)
done := make(chan struct{})
go func() {
defer close(done)
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()
_, proc := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "5"})
go func() {
err := proc.Wait()
select {
case <-done:
default:
assert.NoError(t, err)
}
}()
// It's possible that the socket is created but the server is not ready to
// accept connections yet. We need to retry until we can connect.
//
// Note that we wait long here because if the tailnet connection has trouble
// connecting, it could take 5 seconds or more to reconnect.
var conn net.Conn
require.Eventually(t, func() bool {
var err error
conn, err = net.Dial("unix", remoteSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
conn, err := net.Dial("unix", remoteSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
b := make([]byte, 4)
_, err = conn.Read(b)
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()
<-done
_ = proc.Kill()
requireEcho(t, conn)
}
func TestAgent_SFTP(t *testing.T) {
@ -2063,50 +1885,14 @@ func TestAgent_DebugServer(t *testing.T) {
})
}
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) (*ptytest.PTYCmd, pty.Process) {
//nolint:dogsled
// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
//nolint: dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
listener, err := net.Listen("tcp", "127.0.0.1:0")
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
waitGroup := sync.WaitGroup{}
go func() {
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
ssh, err := agentConn.SSH(ctx)
cancel()
if err != nil {
_ = conn.Close()
return
}
waitGroup.Add(1)
go func() {
agentssh.Bicopy(context.Background(), conn, ssh)
waitGroup.Done()
}()
}
}()
t.Cleanup(func() {
_ = listener.Close()
waitGroup.Wait()
})
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
require.True(t, valid)
args := append(beforeArgs,
"-o", "HostName "+tcpAddr.IP.String(),
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"host",
)
args = append(args, afterArgs...)
cmd := pty.Command("ssh", args...)
return ptytest.Start(t, cmd)
t.Cleanup(func() { sshClient.Close() })
return sshClient
}
func setupSSHSession(
@ -2580,3 +2366,47 @@ func (s *syncWriter) Write(p []byte) (int, error) {
defer s.mu.Unlock()
return s.w.Write(p)
}
// pickRandomPort picks a random port number from the ephemeral range. We do this entirely randomly
// instead of opening a listener and closing it to find a port that is likely to be free, since
// sometimes the OS reallocates the port very quickly.
//
// The returned port lies in [minPort, maxPort], both ends inclusive. Nothing here checks that the
// port is actually free, so callers should be prepared to retry with a fresh pick.
func pickRandomPort() uint16 {
	const (
		// Overlap of windows, linux in https://en.wikipedia.org/wiki/Ephemeral_port
		minPort = 49152
		maxPort = 60999
	)
	// rand.Intn(n) returns a value in the half-open interval [0, n), so the span needs the +1
	// for maxPort itself to be selectable.
	offset := rand.Intn(maxPort - minPort + 1) //nolint: gosec
	return uint16(minPort + offset)
}
// echoOnce accepts a single connection on ll, reads exactly 4 bytes and echoes them back.
//
// It blocks in Accept, so the tests in this file run it on its own goroutine. An Accept error
// returns silently without failing the test (presumably because the listener is expected to be
// closed by the caller during teardown -- confirm if relying on this). Read and write failures are
// reported through assert, which marks the test failed without calling FailNow.
func echoOnce(t *testing.T, ll net.Listener) {
	t.Helper()
	conn, err := ll.Accept()
	if err != nil {
		return
	}
	defer conn.Close()

	b := make([]byte, 4)
	// A single Read is allowed to return fewer bytes than requested; ReadFull keeps reading until
	// the buffer is filled so we never echo a truncated payload.
	if _, err = io.ReadFull(conn, b); !assert.NoError(t, err) {
		return
	}
	_, err = conn.Write(b)
	assert.NoError(t, err)
}
// requireEcho sends the 4 bytes "test" on conn and requires that the same 4 bytes are read back.
//
// Any write error, read error, or payload mismatch fails the test immediately via require, so it
// must be called from the test's own goroutine.
func requireEcho(t *testing.T, conn net.Conn) {
	t.Helper()
	_, err := conn.Write([]byte("test"))
	require.NoError(t, err)

	b := make([]byte, 4)
	// A single Read may return only part of the response; ReadFull waits for all 4 bytes so a
	// short read cannot cause a spurious mismatch below.
	_, err = io.ReadFull(conn, b)
	require.NoError(t, err)
	require.Equal(t, "test", string(b))
}