mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat: change port-forward to opportunistically listen on IPv6 (#15640)
If the local IP address is not explicitly set, previously we assumed 127.0.0.1 (that is, IPv4 only localhost). This PR adds support to opportunistically _also_ listen on IPv6 ::1.
This commit is contained in:
@ -25,6 +25,14 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
var (
|
||||
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
|
||||
// when the local address is not specified in port-forward flags.
|
||||
noAddr netip.Addr
|
||||
ipv6Loopback = netip.MustParseAddr("::1")
|
||||
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
|
||||
)
|
||||
|
||||
func (r *RootCmd) portForward() *serpent.Command {
|
||||
var (
|
||||
tcpForwards []string // <port>:<port>
|
||||
@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
|
||||
// Start all listeners.
|
||||
var (
|
||||
wg = new(sync.WaitGroup)
|
||||
listeners = make([]net.Listener, len(specs))
|
||||
listeners = make([]net.Listener, 0, len(specs)*2)
|
||||
closeAllListeners = func() {
|
||||
logger.Debug(ctx, "closing all listeners")
|
||||
for _, l := range listeners {
|
||||
@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
|
||||
)
|
||||
defer closeAllListeners()
|
||||
|
||||
for i, spec := range specs {
|
||||
for _, spec := range specs {
|
||||
if spec.listenHost == noAddr {
|
||||
// first, opportunistically try to listen on IPv6
|
||||
spec6 := spec
|
||||
spec6.listenHost = ipv6Loopback
|
||||
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
|
||||
if err6 != nil {
|
||||
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
|
||||
} else {
|
||||
listeners = append(listeners, l6)
|
||||
}
|
||||
spec.listenHost = ipv4Loopback
|
||||
}
|
||||
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
|
||||
return err
|
||||
}
|
||||
listeners[i] = l
|
||||
listeners = append(listeners, l)
|
||||
}
|
||||
|
||||
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
|
||||
@ -206,12 +226,19 @@ func listenAndPortForward(
|
||||
spec portForwardSpec,
|
||||
logger slog.Logger,
|
||||
) (net.Listener, error) {
|
||||
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
|
||||
logger = logger.With(
|
||||
slog.F("network", spec.network),
|
||||
slog.F("listen_host", spec.listenHost),
|
||||
slog.F("listen_port", spec.listenPort),
|
||||
)
|
||||
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
|
||||
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
|
||||
spec.network, listenAddress, spec.network, dialAddress)
|
||||
|
||||
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||
l, err := inv.Net.Listen(spec.network, listenAddress.String())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
|
||||
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
|
||||
}
|
||||
logger.Debug(ctx, "listening")
|
||||
|
||||
@ -226,24 +253,31 @@ func listenAndPortForward(
|
||||
logger.Debug(ctx, "listener closed")
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
|
||||
_, _ = fmt.Fprintf(inv.Stderr,
|
||||
"Error accepting connection from '%s://%s': %v\n",
|
||||
spec.network, listenAddress.String(), err)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
|
||||
return
|
||||
}
|
||||
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
logger.Debug(ctx, "accepted connection",
|
||||
slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
|
||||
go func(netConn net.Conn) {
|
||||
defer netConn.Close()
|
||||
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
|
||||
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
|
||||
_, _ = fmt.Fprintf(inv.Stderr,
|
||||
"Failed to dial '%s://%s' in workspace: %s\n",
|
||||
spec.network, dialAddress, err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
logger.Debug(ctx,
|
||||
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
|
||||
agentssh.Bicopy(ctx, netConn, remoteConn)
|
||||
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
logger.Debug(ctx,
|
||||
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
|
||||
}(netConn)
|
||||
}
|
||||
}(spec)
|
||||
@ -252,11 +286,9 @@ func listenAndPortForward(
|
||||
}
|
||||
|
||||
type portForwardSpec struct {
|
||||
listenNetwork string // tcp, udp
|
||||
listenAddress string // <ip>:<port> or path
|
||||
|
||||
dialNetwork string // tcp, udp
|
||||
dialAddress string // <ip>:<port> or path
|
||||
network string // tcp, udp
|
||||
listenHost netip.Addr
|
||||
listenPort, dialPort uint16
|
||||
}
|
||||
|
||||
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
|
||||
@ -264,36 +296,28 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
|
||||
|
||||
for _, specEntry := range tcpSpecs {
|
||||
for _, spec := range strings.Split(specEntry, ",") {
|
||||
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
|
||||
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "tcp",
|
||||
listenAddress: port.local.String(),
|
||||
dialNetwork: "tcp",
|
||||
dialAddress: port.remote.String(),
|
||||
})
|
||||
for _, pfSpec := range pfSpecs {
|
||||
pfSpec.network = "tcp"
|
||||
specs = append(specs, pfSpec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, specEntry := range udpSpecs {
|
||||
for _, spec := range strings.Split(specEntry, ",") {
|
||||
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
|
||||
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "udp",
|
||||
listenAddress: port.local.String(),
|
||||
dialNetwork: "udp",
|
||||
dialAddress: port.remote.String(),
|
||||
})
|
||||
for _, pfSpec := range pfSpecs {
|
||||
pfSpec.network = "udp"
|
||||
specs = append(specs, pfSpec)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -301,9 +325,9 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
|
||||
// Check for duplicate entries.
|
||||
locals := map[string]struct{}{}
|
||||
for _, spec := range specs {
|
||||
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
|
||||
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
|
||||
if _, ok := locals[localStr]; ok {
|
||||
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
|
||||
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
|
||||
}
|
||||
locals[localStr] = struct{}{}
|
||||
}
|
||||
@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
type parsedSrcDestPort struct {
|
||||
local, remote netip.AddrPort
|
||||
}
|
||||
|
||||
// specRegexp matches port specs. It handles all the following formats:
|
||||
//
|
||||
// 8000
|
||||
@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
|
||||
// 9: end or remote port range
|
||||
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)
|
||||
|
||||
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
|
||||
var (
|
||||
err error
|
||||
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
)
|
||||
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
|
||||
groups := specRegexp.FindStringSubmatch(in)
|
||||
if len(groups) == 0 {
|
||||
return nil, xerrors.Errorf("invalid port specification %q", in)
|
||||
}
|
||||
|
||||
var localAddr netip.Addr
|
||||
if groups[2] != "" {
|
||||
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
|
||||
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
|
||||
}
|
||||
localAddr = parsedAddr
|
||||
}
|
||||
|
||||
local, err := parsePortRange(groups[3], groups[5])
|
||||
@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
|
||||
if len(local) != len(remote) {
|
||||
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
|
||||
}
|
||||
var out []parsedSrcDestPort
|
||||
var out []portForwardSpec
|
||||
for i := range local {
|
||||
out = append(out, parsedSrcDestPort{
|
||||
local: netip.AddrPortFrom(localAddr, local[i]),
|
||||
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
|
||||
out = append(out, portForwardSpec{
|
||||
listenHost: localAddr,
|
||||
listenPort: local[i],
|
||||
dialPort: remote[i],
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
|
@ -29,15 +29,15 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"tcp", "127.0.0.1:8000", "tcp", "127.0.0.1:8000"},
|
||||
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
|
||||
{"tcp", "127.0.0.1:9000", "tcp", "127.0.0.1:9000"},
|
||||
{"tcp", "127.0.0.1:9001", "tcp", "127.0.0.1:9001"},
|
||||
{"tcp", "127.0.0.1:9002", "tcp", "127.0.0.1:9002"},
|
||||
{"tcp", "127.0.0.1:9003", "tcp", "127.0.0.1:9005"},
|
||||
{"tcp", "127.0.0.1:9004", "tcp", "127.0.0.1:9006"},
|
||||
{"tcp", "127.0.0.1:10000", "tcp", "127.0.0.1:10000"},
|
||||
{"tcp", "127.0.0.1:4444", "tcp", "127.0.0.1:4444"},
|
||||
{"tcp", noAddr, 8000, 8000},
|
||||
{"tcp", noAddr, 8080, 8081},
|
||||
{"tcp", noAddr, 9000, 9000},
|
||||
{"tcp", noAddr, 9001, 9001},
|
||||
{"tcp", noAddr, 9002, 9002},
|
||||
{"tcp", noAddr, 9003, 9005},
|
||||
{"tcp", noAddr, 9004, 9006},
|
||||
{"tcp", noAddr, 10000, 10000},
|
||||
{"tcp", noAddr, 4444, 4444},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -46,7 +46,7 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
tcpSpecs: []string{"127.0.0.1:8080:8081"},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"tcp", "127.0.0.1:8080", "tcp", "127.0.0.1:8081"},
|
||||
{"tcp", ipv4Loopback, 8080, 8081},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -55,7 +55,7 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
tcpSpecs: []string{"[::1]:8080:8081"},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"tcp", "[::1]:8080", "tcp", "127.0.0.1:8081"},
|
||||
{"tcp", ipv6Loopback, 8080, 8081},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -64,9 +64,9 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
udpSpecs: []string{"8000,8080-8081"},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"udp", "127.0.0.1:8000", "udp", "127.0.0.1:8000"},
|
||||
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8080"},
|
||||
{"udp", "127.0.0.1:8081", "udp", "127.0.0.1:8081"},
|
||||
{"udp", noAddr, 8000, 8000},
|
||||
{"udp", noAddr, 8080, 8080},
|
||||
{"udp", noAddr, 8081, 8081},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -75,7 +75,7 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
udpSpecs: []string{"127.0.0.1:8080:8081"},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"udp", "127.0.0.1:8080", "udp", "127.0.0.1:8081"},
|
||||
{"udp", ipv4Loopback, 8080, 8081},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -84,7 +84,7 @@ func Test_parsePortForwards(t *testing.T) {
|
||||
udpSpecs: []string{"[::1]:8080:8081"},
|
||||
},
|
||||
want: []portForwardSpec{
|
||||
{"udp", "[::1]:8080", "udp", "127.0.0.1:8081"},
|
||||
{"udp", ipv6Loopback, 8080, 8081},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -67,6 +67,17 @@ func TestPortForward(t *testing.T) {
|
||||
},
|
||||
localAddress: []string{"127.0.0.1:5555", "127.0.0.1:6666"},
|
||||
},
|
||||
{
|
||||
name: "TCP-opportunistic-ipv6",
|
||||
network: "tcp",
|
||||
flag: []string{"--tcp=5566:%v", "--tcp=6655:%v"},
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "create TCP listener")
|
||||
return l
|
||||
},
|
||||
localAddress: []string{"[::1]:5566", "[::1]:6655"},
|
||||
},
|
||||
{
|
||||
name: "UDP",
|
||||
network: "udp",
|
||||
@ -82,6 +93,21 @@ func TestPortForward(t *testing.T) {
|
||||
},
|
||||
localAddress: []string{"127.0.0.1:7777", "127.0.0.1:8888"},
|
||||
},
|
||||
{
|
||||
name: "UDP-opportunistic-ipv6",
|
||||
network: "udp",
|
||||
flag: []string{"--udp=7788:%v", "--udp=8877:%v"},
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
addr := net.UDPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
}
|
||||
l, err := udp.Listen("udp", &addr)
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
localAddress: []string{"[::1]:7788", "[::1]:8877"},
|
||||
},
|
||||
{
|
||||
name: "TCPWithAddress",
|
||||
network: "tcp", flag: []string{"--tcp=10.10.10.99:9999:%v", "--tcp=10.10.10.10:1010:%v"},
|
||||
@ -295,6 +321,63 @@ func TestPortForward(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, updated.LastUsedAt, workspace.LastUsedAt)
|
||||
})
|
||||
|
||||
t.Run("IPv6Busy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
remoteLis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "create TCP listener")
|
||||
p1 := setupTestListener(t, remoteLis)
|
||||
|
||||
// Create a flag that forwards from local 5555 to remote listener port.
|
||||
flag := fmt.Sprintf("--tcp=5555:%v", p1)
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listener.
|
||||
inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag)
|
||||
clitest.SetupConfig(t, member, root)
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
inv.Stderr = pty.Output()
|
||||
|
||||
iNet := newInProcNet()
|
||||
inv.Net = iNet
|
||||
|
||||
// listen on port 5555 on IPv6 so it's busy when we try to port forward
|
||||
busyLis, err := iNet.Listen("tcp", "[::1]:5555")
|
||||
require.NoError(t, err)
|
||||
defer busyLis.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
t.Logf("command complete; err=%s", err.Error())
|
||||
errC <- err
|
||||
}()
|
||||
pty.ExpectMatchContext(ctx, "Ready!")
|
||||
|
||||
// Test IPv4 still works
|
||||
dialCtx, dialCtxCancel := context.WithTimeout(ctx, testutil.WaitShort)
|
||||
defer dialCtxCancel()
|
||||
c1, err := iNet.dial(dialCtx, addr{"tcp", "127.0.0.1:5555"})
|
||||
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||
defer c1.Close()
|
||||
testDial(t, c1)
|
||||
|
||||
cancel()
|
||||
err = <-errC
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
|
||||
flushCtx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.RequireSendCtx(flushCtx, t, wuTick, dbtime.Now())
|
||||
_ = testutil.RequireRecvCtx(flushCtx, t, wuFlush)
|
||||
updated, err := client.Workspace(context.Background(), workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, updated.LastUsedAt, workspace.LastUsedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// runAgent creates a fake workspace and starts an agent locally for that
|
||||
|
Reference in New Issue
Block a user