test: add test that we close stdin on SSH session close (#18711)

closes #18519

Adds a unit test that verifies that we close the stdin to a non-TTY process when the SSH session connected to it exits.

c.f. https://github.com/coder/coder/issues/18519#issuecomment-3027609871

Validates that we match OpenSSH behavior.
This commit is contained in:
Spike Curtis
2025-07-02 16:23:07 +04:00
committed by GitHub
parent 8a69f6af17
commit 59c8b560fa
2 changed files with 80 additions and 1 deletions

View File

@ -609,7 +609,9 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
// and SSH server close may be delayed. // and SSH server close may be delayed.
cmd.SysProcAttr = cmdSysProcAttr() cmd.SysProcAttr = cmdSysProcAttr()
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. OpenSSH closes the
// pipes to the process when the session ends; which is what happens here since we wire the command up to the
// session for I/O.
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271 // c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
cmd.Cancel = nil cmd.Cancel = nil

View File

@ -8,7 +8,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"os/user" "os/user"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
}) })
} }
func TestSSHServer_ClosesStdin(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("bash doesn't exist on Windows")
}
ctx := testutil.Context(t, testutil.WaitMedium)
logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err)
defer s.Close()
err = s.UpdateHostSigner(42)
assert.NoError(t, err)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
err := s.Serve(ln)
assert.Error(t, err) // Server is closed.
}()
defer func() {
err := s.Close()
require.NoError(t, err)
<-done
}()
c := sshClient(t, ln.Addr().String())
sess, err := c.NewSession()
require.NoError(t, err)
stdout, err := sess.StdoutPipe()
require.NoError(t, err)
stdin, err := sess.StdinPipe()
require.NoError(t, err)
defer stdin.Close()
dir := t.TempDir()
err = os.MkdirAll(dir, 0o755)
require.NoError(t, err)
filePath := filepath.Join(dir, "result.txt")
// the shell command `read` will block until data is written to stdin, or closed. It will return
// exit code 1 if it hits EOF, which is what we want to test.
cmdErrCh := make(chan error, 1)
go func() {
cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath))
}()
cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh)
require.NoError(t, cmdErr)
readCh := make(chan error, 1)
go func() {
buf := make([]byte, 8)
_, err := stdout.Read(buf)
assert.Equal(t, "started\n", string(buf))
readCh <- err
}()
err = testutil.RequireReceive(ctx, t, readCh)
require.NoError(t, err)
sess.Close()
var content []byte
testutil.Eventually(ctx, t, func(_ context.Context) bool {
content, err = os.ReadFile(filePath)
return err == nil
}, testutil.IntervalFast)
require.NoError(t, err)
require.Equal(t, "read exit code: 1\n", string(content))
}
func sshClient(t *testing.T, addr string) *ssh.Client { func sshClient(t *testing.T, addr string) *ssh.Client {
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
require.NoError(t, err) require.NoError(t, err)