feat: change port-forward to opportunistically listen on IPv6 (#15640)

If the local IP address is not explicitly set, previously we assumed 127.0.0.1 (that is, IPv4 only localhost). This PR adds support to opportunistically _also_ listen on IPv6 ::1.
This commit is contained in:
Spike Curtis
2024-11-25 16:33:28 +04:00
committed by GitHub
parent 1cdc3e8921
commit e6506f0679
3 changed files with 169 additions and 67 deletions

View File

@ -25,6 +25,14 @@ import (
"github.com/coder/serpent"
)
var (
// noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
// when the local address is not specified in port-forward flags.
noAddr netip.Addr
ipv6Loopback = netip.MustParseAddr("::1")
ipv4Loopback = netip.MustParseAddr("127.0.0.1")
)
func (r *RootCmd) portForward() *serpent.Command {
var (
tcpForwards []string // <port>:<port>
@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
// Start all listeners.
var (
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
listeners = make([]net.Listener, 0, len(specs)*2)
closeAllListeners = func() {
logger.Debug(ctx, "closing all listeners")
for _, l := range listeners {
@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
)
defer closeAllListeners()
for i, spec := range specs {
for _, spec := range specs {
if spec.listenHost == noAddr {
// first, opportunistically try to listen on IPv6
spec6 := spec
spec6.listenHost = ipv6Loopback
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
if err6 != nil {
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
} else {
listeners = append(listeners, l6)
}
spec.listenHost = ipv4Loopback
}
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
if err != nil {
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
return err
}
listeners[i] = l
listeners = append(listeners, l)
}
stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID)
@ -206,12 +226,19 @@ func listenAndPortForward(
spec portForwardSpec,
logger slog.Logger,
) (net.Listener, error) {
logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress))
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
logger = logger.With(
slog.F("network", spec.network),
slog.F("listen_host", spec.listenHost),
slog.F("listen_port", spec.listenPort),
)
listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort)
dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort)
_, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n",
spec.network, listenAddress, spec.network, dialAddress)
l, err := inv.Net.Listen(spec.listenNetwork, spec.listenAddress)
l, err := inv.Net.Listen(spec.network, listenAddress.String())
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
return nil, xerrors.Errorf("listen '%s://%s': %w", spec.network, listenAddress.String(), err)
}
logger.Debug(ctx, "listening")
@ -226,24 +253,31 @@ func listenAndPortForward(
logger.Debug(ctx, "listener closed")
return
}
_, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Error accepting connection from '%s://%s': %v\n",
spec.network, listenAddress.String(), err)
_, _ = fmt.Fprintln(inv.Stderr, "Killing listener")
return
}
logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx, "accepted connection",
slog.F("remote_addr", netConn.RemoteAddr()))
go func(netConn net.Conn) {
defer netConn.Close()
remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress)
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
if err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
_, _ = fmt.Fprintf(inv.Stderr,
"Failed to dial '%s://%s' in workspace: %s\n",
spec.network, dialAddress, err)
return
}
defer remoteConn.Close()
logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
agentssh.Bicopy(ctx, netConn, remoteConn)
logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
logger.Debug(ctx,
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))
}(netConn)
}
}(spec)
@ -252,11 +286,9 @@ func listenAndPortForward(
}
type portForwardSpec struct {
listenNetwork string // tcp, udp
listenAddress string // <ip>:<port> or path
dialNetwork string // tcp, udp
dialAddress string // <ip>:<port> or path
network string // tcp, udp
listenHost netip.Addr
listenPort, dialPort uint16
}
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
@ -264,36 +296,28 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
for _, specEntry := range tcpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
}
for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "tcp",
listenAddress: port.local.String(),
dialNetwork: "tcp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "tcp"
specs = append(specs, pfSpec)
}
}
}
for _, specEntry := range udpSpecs {
for _, spec := range strings.Split(specEntry, ",") {
ports, err := parseSrcDestPorts(strings.TrimSpace(spec))
pfSpecs, err := parseSrcDestPorts(strings.TrimSpace(spec))
if err != nil {
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
}
for _, port := range ports {
specs = append(specs, portForwardSpec{
listenNetwork: "udp",
listenAddress: port.local.String(),
dialNetwork: "udp",
dialAddress: port.remote.String(),
})
for _, pfSpec := range pfSpecs {
pfSpec.network = "udp"
specs = append(specs, pfSpec)
}
}
}
@ -301,9 +325,9 @@ func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
// Check for duplicate entries.
locals := map[string]struct{}{}
for _, spec := range specs {
localStr := fmt.Sprintf("%v:%v", spec.listenNetwork, spec.listenAddress)
localStr := fmt.Sprintf("%s:%s:%d", spec.network, spec.listenHost, spec.listenPort)
if _, ok := locals[localStr]; ok {
return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork, spec.listenAddress)
return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network, spec.listenHost, spec.listenPort)
}
locals[localStr] = struct{}{}
}
@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
return uint16(port), nil
}
type parsedSrcDestPort struct {
local, remote netip.AddrPort
}
// specRegexp matches port specs. It handles all the following formats:
//
// 8000
@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
// 9: end or remote port range
var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`)
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
var (
err error
localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
)
func parseSrcDestPorts(in string) ([]portForwardSpec, error) {
groups := specRegexp.FindStringSubmatch(in)
if len(groups) == 0 {
return nil, xerrors.Errorf("invalid port specification %q", in)
}
var localAddr netip.Addr
if groups[2] != "" {
localAddr, err = netip.ParseAddr(strings.Trim(groups[2], "[]"))
parsedAddr, err := netip.ParseAddr(strings.Trim(groups[2], "[]"))
if err != nil {
return nil, xerrors.Errorf("invalid IP address %q", groups[2])
}
localAddr = parsedAddr
}
local, err := parsePortRange(groups[3], groups[5])
@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
if len(local) != len(remote) {
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
}
var out []parsedSrcDestPort
var out []portForwardSpec
for i := range local {
out = append(out, parsedSrcDestPort{
local: netip.AddrPortFrom(localAddr, local[i]),
remote: netip.AddrPortFrom(remoteAddr, remote[i]),
out = append(out, portForwardSpec{
listenHost: localAddr,
listenPort: local[i],
dialPort: remote[i],
})
}
return out, nil