mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
feat: Set SSH env vars: SSH_CLIENT
, SSH_CONNECTION
and SSH_TTY
(#3622)
Fixes #2339
This commit is contained in:
committed by
GitHub
parent
9c0cd5287c
commit
e44f7adb7e
@ -404,6 +404,15 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
|
|||||||
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
|
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
|
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
|
||||||
|
|
||||||
|
// Set SSH connection environment variables (these are also set by OpenSSH
|
||||||
|
// and thus expected to be present by SSH clients). Since the agent does
|
||||||
|
// networking in-memory, trying to provide accurate values here would be
|
||||||
|
// nonsensical. For now, we hard code these values so that they're present.
|
||||||
|
srcAddr, srcPort := "0.0.0.0", "0"
|
||||||
|
dstAddr, dstPort := "0.0.0.0", "0"
|
||||||
|
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
|
||||||
|
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
|
||||||
|
|
||||||
// Load environment variables passed via the agent.
|
// Load environment variables passed via the agent.
|
||||||
// These should override all variables we manually specify.
|
// These should override all variables we manually specify.
|
||||||
for envKey, value := range metadata.EnvironmentVariables {
|
for envKey, value := range metadata.EnvironmentVariables {
|
||||||
@ -441,6 +450,8 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
|||||||
sshPty, windowSize, isPty := session.Pty()
|
sshPty, windowSize, isPty := session.Pty()
|
||||||
if isPty {
|
if isPty {
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||||
|
|
||||||
|
// The pty package sets `SSH_TTY` on supported platforms.
|
||||||
ptty, process, err := pty.Start(cmd)
|
ptty, process, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("start command: %w", err)
|
return xerrors.Errorf("start command: %w", err)
|
||||||
|
@ -252,6 +252,29 @@ func TestAgent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SSH connection env vars", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Note: the SSH_TTY environment variable should only be set for TTYs.
|
||||||
|
// For some reason this test produces a TTY locally and a non-TTY in CI
|
||||||
|
// so we don't test for the absence of SSH_TTY.
|
||||||
|
for _, key := range []string{"SSH_CONNECTION", "SSH_CLIENT"} {
|
||||||
|
key := key
|
||||||
|
t.Run(key, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
session := setupSSHSession(t, agent.Metadata{})
|
||||||
|
command := "sh -c 'echo $" + key + "'"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
command = "cmd.exe /c echo %" + key + "%"
|
||||||
|
}
|
||||||
|
output, err := session.Output(command)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, strings.TrimSpace(string(output)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("StartupScript", func(t *testing.T) {
|
t.Run("StartupScript", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tempPath := filepath.Join(t.TempDir(), "content.txt")
|
tempPath := filepath.Join(t.TempDir(), "content.txt")
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package pty
|
package pty
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
@ -18,6 +19,8 @@ func startPty(cmd *exec.Cmd) (PTY, Process, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, xerrors.Errorf("open: %w", err)
|
return nil, nil, xerrors.Errorf("open: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_PTY=%s", tty.Name()))
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
Setsid: true,
|
Setsid: true,
|
||||||
Setctty: true,
|
Setctty: true,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
//go:build !windows
|
//go:build !windows
|
||||||
// +build !windows
|
|
||||||
|
|
||||||
package pty_test
|
package pty_test
|
||||||
|
|
||||||
@ -40,4 +39,12 @@ func TestStart(t *testing.T) {
|
|||||||
require.True(t, xerrors.As(err, &exitErr))
|
require.True(t, xerrors.As(err, &exitErr))
|
||||||
assert.NotEqual(t, 0, exitErr.ExitCode())
|
assert.NotEqual(t, 0, exitErr.ExitCode())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("SSH_PTY", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
pty, ps := ptytest.Start(t, exec.Command("env"))
|
||||||
|
pty.ExpectMatch("SSH_PTY=/dev/")
|
||||||
|
err := ps.Wait()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user