// Mirror of https://github.com/coder/coder.git
// Synced 2025-03-14 10:09:57 +00:00
// 148 lines, 3.8 KiB, Go

package cli
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"regexp"
|
|
"strconv"
|
|
|
|
gossh "golang.org/x/crypto/ssh"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/agent/agentssh"
|
|
)
|
|
|
|
// cookieAddr is a special net.Addr accepted by sshRemoteForward() which
// carries a cookie. The cookie is written to the local connection before any
// forwarded traffic is copied.
type cookieAddr struct {
	// Addr is the wrapped address. Its Network() and String() methods are
	// promoted, so a cookieAddr can be dialed like the address it wraps.
	net.Addr
	// cookie holds the bytes written to the connection right after dialing.
	cookie []byte
}
|
|
|
|
// Patterns for the two supported remote forward flag formats.
var (
	// remoteForwardRegexTCP matches a TCP forward.
	// Format: remote_port:local_address:local_port
	remoteForwardRegexTCP = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)

	// remoteForwardRegexUnixSocket matches a unix socket forward.
	// Format: remote_socket_path:local_socket_path (both absolute paths)
	remoteForwardRegexUnixSocket = regexp.MustCompile(`^(\/.+):(\/.+)$`)
)

// isRemoteForwardTCP reports whether flag has the TCP remote forward format.
func isRemoteForwardTCP(flag string) bool {
	return remoteForwardRegexTCP.MatchString(flag)
}

// isRemoteForwardUnixSocket reports whether flag has the unix socket remote
// forward format.
func isRemoteForwardUnixSocket(flag string) bool {
	return remoteForwardRegexUnixSocket.MatchString(flag)
}

// validateRemoteForward reports whether flag matches either supported remote
// forward format. It only checks the shape of the flag; the individual
// components are not interpreted here.
func validateRemoteForward(flag string) bool {
	return isRemoteForwardTCP(flag) || isRemoteForwardUnixSocket(flag)
}
|
|
|
|
func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
|
|
remotePort, err := strconv.Atoi(matches[1])
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
|
|
}
|
|
localAddress, err := net.ResolveIPAddr("ip", matches[2])
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("local address is invalid: %w", err)
|
|
}
|
|
localPort, err := strconv.Atoi(matches[3])
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("local port is invalid: %w", err)
|
|
}
|
|
|
|
localAddr := &net.TCPAddr{
|
|
IP: localAddress.IP,
|
|
Port: localPort,
|
|
}
|
|
|
|
remoteAddr := &net.TCPAddr{
|
|
IP: net.ParseIP("127.0.0.1"),
|
|
Port: remotePort,
|
|
}
|
|
return localAddr, remoteAddr, nil
|
|
}
|
|
|
|
// parseRemoteForwardUnixSocket builds the local and remote unix socket
// addresses for a unix socket remote forward.
//
// matches must contain the full match followed by the remote socket path and
// the local socket path. The local address is returned first, the remote
// address second; the error is always nil.
//
// Note that we don't verify that the local socket path exists because the
// user may create it later. This behavior matches OpenSSH.
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
	const network = "unix"

	remotePath, localPath := matches[1], matches[2]

	local := &net.UnixAddr{Name: localPath, Net: network}
	remote := &net.UnixAddr{Name: remotePath, Net: network}
	return local, remote, nil
}
|
|
|
|
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
|
|
tcpMatches := remoteForwardRegexTCP.FindStringSubmatch(flag)
|
|
|
|
if len(tcpMatches) > 0 {
|
|
return parseRemoteForwardTCP(tcpMatches)
|
|
}
|
|
|
|
unixSocketMatches := remoteForwardRegexUnixSocket.FindStringSubmatch(flag)
|
|
if len(unixSocketMatches) > 0 {
|
|
return parseRemoteForwardUnixSocket(unixSocketMatches)
|
|
}
|
|
|
|
return nil, nil, xerrors.New("Could not match forward arguments")
|
|
}
|
|
|
|
// sshRemoteForward starts forwarding connections from a remote listener to a
|
|
// local address via SSH in a goroutine.
|
|
//
|
|
// Accepts a `cookieAddr` as the local address.
|
|
func sshRemoteForward(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) {
|
|
listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err)
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
remoteConn, err := listener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() == nil {
|
|
_, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
defer remoteConn.Close()
|
|
|
|
localConn, err := net.Dial(localAddr.Network(), localAddr.String())
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err)
|
|
return
|
|
}
|
|
defer localConn.Close()
|
|
|
|
if c, ok := localAddr.(cookieAddr); ok {
|
|
_, err = localConn.Write(c.cookie)
|
|
if err != nil {
|
|
_, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
agentssh.Bicopy(ctx, localConn, remoteConn)
|
|
}()
|
|
}
|
|
}()
|
|
|
|
return listener, nil
|
|
}
|