mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat(cli): add reverse tunnelling SSH support for unix sockets (#9976)
This commit is contained in:
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
@ -23,15 +24,24 @@ type cookieAddr struct {
|
||||
|
||||
// Format:
|
||||
// remote_port:local_address:local_port
|
||||
var remoteForwardRegex = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
|
||||
var remoteForwardRegexTCP = regexp.MustCompile(`^(\d+):(.+):(\d+)$`)
|
||||
|
||||
func validateRemoteForward(flag string) bool {
|
||||
return remoteForwardRegex.MatchString(flag)
|
||||
// remote_socket_path:local_socket_path (both absolute paths)
|
||||
var remoteForwardRegexUnixSocket = regexp.MustCompile(`^(\/.+):(\/.+)$`)
|
||||
|
||||
func isRemoteForwardTCP(flag string) bool {
|
||||
return remoteForwardRegexTCP.MatchString(flag)
|
||||
}
|
||||
|
||||
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
|
||||
matches := remoteForwardRegex.FindStringSubmatch(flag)
|
||||
func isRemoteForwardUnixSocket(flag string) bool {
|
||||
return remoteForwardRegexUnixSocket.MatchString(flag)
|
||||
}
|
||||
|
||||
func validateRemoteForward(flag string) bool {
|
||||
return isRemoteForwardTCP(flag) || isRemoteForwardUnixSocket(flag)
|
||||
}
|
||||
|
||||
func parseRemoteForwardTCP(matches []string) (net.Addr, net.Addr, error) {
|
||||
remotePort, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("remote port is invalid: %w", err)
|
||||
@ -57,6 +67,46 @@ func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
func parseRemoteForwardUnixSocket(matches []string) (net.Addr, net.Addr, error) {
|
||||
remoteSocket := matches[1]
|
||||
localSocket := matches[2]
|
||||
|
||||
fileInfo, err := os.Stat(localSocket)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if fileInfo.Mode()&os.ModeSocket == 0 {
|
||||
return nil, nil, xerrors.New("File is not a Unix domain socket file")
|
||||
}
|
||||
|
||||
remoteAddr := &net.UnixAddr{
|
||||
Name: remoteSocket,
|
||||
Net: "unix",
|
||||
}
|
||||
|
||||
localAddr := &net.UnixAddr{
|
||||
Name: localSocket,
|
||||
Net: "unix",
|
||||
}
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
func parseRemoteForward(flag string) (net.Addr, net.Addr, error) {
|
||||
tcpMatches := remoteForwardRegexTCP.FindStringSubmatch(flag)
|
||||
|
||||
if len(tcpMatches) > 0 {
|
||||
return parseRemoteForwardTCP(tcpMatches)
|
||||
}
|
||||
|
||||
unixSocketMatches := remoteForwardRegexUnixSocket.FindStringSubmatch(flag)
|
||||
if len(unixSocketMatches) > 0 {
|
||||
return parseRemoteForwardUnixSocket(unixSocketMatches)
|
||||
}
|
||||
|
||||
return nil, nil, xerrors.New("Could not match forward arguments")
|
||||
}
|
||||
|
||||
// sshRemoteForward starts forwarding connections from a remote listener to a
|
||||
// local address via SSH in a goroutine.
|
||||
//
|
||||
|
@ -428,6 +428,54 @@ func TestSSH(t *testing.T) {
|
||||
<-cmdDone
|
||||
})
|
||||
|
||||
t.Run("RemoteForwardUnixSocket", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Test not supported on windows")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||
|
||||
_ = agenttest.New(t, client.URL, agentToken)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
tmpdir := tempDirUnixSocket(t)
|
||||
agentSock := filepath.Join(tmpdir, "agent.sock")
|
||||
l, err := net.Listen("unix", agentSock)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
inv, root := clitest.New(t,
|
||||
"ssh",
|
||||
workspace.Name,
|
||||
"--remote-forward",
|
||||
"/tmp/test.sock:"+agentSock,
|
||||
)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
inv.Stderr = pty.Output()
|
||||
cmdDone := tGo(t, func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.NoError(t, err, "ssh command failed")
|
||||
})
|
||||
|
||||
// Wait for the prompt or any output really to indicate the command has
|
||||
// started and accepting input on stdin.
|
||||
_ = pty.Peek(ctx, 1)
|
||||
|
||||
// Download the test page
|
||||
pty.WriteLine("ss -xl state listening src /tmp/test.sock | wc -l")
|
||||
pty.ExpectMatch("2")
|
||||
|
||||
// And we're done.
|
||||
pty.WriteLine("exit")
|
||||
<-cmdDone
|
||||
})
|
||||
|
||||
t.Run("FileLogging", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
Reference in New Issue
Block a user