mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
test: add test that we close stdin on SSH session close (#18711)
closes #18519 Adds a unit test that verifies that we close the stdin to a non-TTY process when the SSH session connected to it exits. c.f. https://github.com/coder/coder/issues/18519#issuecomment-3027609871 Validates that we match OpenSSH behavior.
This commit is contained in:
@ -609,7 +609,9 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
|
||||
// and SSH server close may be delayed.
|
||||
cmd.SysProcAttr = cmdSysProcAttr()
|
||||
|
||||
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
|
||||
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. OpenSSH closes the
|
||||
// pipes to the process when the session ends; which is what happens here since we wire the command up to the
|
||||
// session for I/O.
|
||||
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
|
||||
cmd.Cancel = nil
|
||||
|
||||
|
@ -8,7 +8,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHServer_ClosesStdin(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("bash doesn't exist on Windows")
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
logger := testutil.Logger(t)
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
err = s.UpdateHostSigner(42)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
defer func() {
|
||||
err := s.Close()
|
||||
require.NoError(t, err)
|
||||
<-done
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
|
||||
sess, err := c.NewSession()
|
||||
require.NoError(t, err)
|
||||
stdout, err := sess.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
stdin, err := sess.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
defer stdin.Close()
|
||||
|
||||
dir := t.TempDir()
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
require.NoError(t, err)
|
||||
filePath := filepath.Join(dir, "result.txt")
|
||||
|
||||
// the shell command `read` will block until data is written to stdin, or closed. It will return
|
||||
// exit code 1 if it hits EOF, which is what we want to test.
|
||||
cmdErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath))
|
||||
}()
|
||||
|
||||
cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh)
|
||||
require.NoError(t, cmdErr)
|
||||
|
||||
readCh := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 8)
|
||||
_, err := stdout.Read(buf)
|
||||
assert.Equal(t, "started\n", string(buf))
|
||||
readCh <- err
|
||||
}()
|
||||
err = testutil.RequireReceive(ctx, t, readCh)
|
||||
require.NoError(t, err)
|
||||
|
||||
sess.Close()
|
||||
|
||||
var content []byte
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
content, err = os.ReadFile(filePath)
|
||||
return err == nil
|
||||
}, testutil.IntervalFast)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "read exit code: 1\n", string(content))
|
||||
}
|
||||
|
||||
func sshClient(t *testing.T, addr string) *ssh.Client {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
Reference in New Issue
Block a user