mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
test: add test that we close stdin on SSH session close (#18711)
closes #18519 Adds a unit test that verifies that we close the stdin to a non-TTY process when the SSH session connected to it exits. c.f. https://github.com/coder/coder/issues/18519#issuecomment-3027609871 Validates that we match OpenSSH behavior.
This commit is contained in:
@ -609,7 +609,9 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
|
|||||||
// and SSH server close may be delayed.
|
// and SSH server close may be delayed.
|
||||||
cmd.SysProcAttr = cmdSysProcAttr()
|
cmd.SysProcAttr = cmdSysProcAttr()
|
||||||
|
|
||||||
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
|
// to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends. OpenSSH closes the
|
||||||
|
// pipes to the process when the session ends; which is what happens here since we wire the command up to the
|
||||||
|
// session for I/O.
|
||||||
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
|
// c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
|
||||||
cmd.Cancel = nil
|
cmd.Cancel = nil
|
||||||
|
|
||||||
|
@ -8,7 +8,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSSHServer_ClosesStdin(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("bash doesn't exist on Windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
logger := testutil.Logger(t)
|
||||||
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
err = s.UpdateHostSigner(42)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
err := s.Serve(ln)
|
||||||
|
assert.Error(t, err) // Server is closed.
|
||||||
|
}()
|
||||||
|
defer func() {
|
||||||
|
err := s.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
<-done
|
||||||
|
}()
|
||||||
|
|
||||||
|
c := sshClient(t, ln.Addr().String())
|
||||||
|
|
||||||
|
sess, err := c.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdout, err := sess.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdin, err := sess.StdinPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer stdin.Close()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
err = os.MkdirAll(dir, 0o755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
filePath := filepath.Join(dir, "result.txt")
|
||||||
|
|
||||||
|
// the shell command `read` will block until data is written to stdin, or closed. It will return
|
||||||
|
// exit code 1 if it hits EOF, which is what we want to test.
|
||||||
|
cmdErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath))
|
||||||
|
}()
|
||||||
|
|
||||||
|
cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh)
|
||||||
|
require.NoError(t, cmdErr)
|
||||||
|
|
||||||
|
readCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
_, err := stdout.Read(buf)
|
||||||
|
assert.Equal(t, "started\n", string(buf))
|
||||||
|
readCh <- err
|
||||||
|
}()
|
||||||
|
err = testutil.RequireReceive(ctx, t, readCh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sess.Close()
|
||||||
|
|
||||||
|
var content []byte
|
||||||
|
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||||
|
content, err = os.ReadFile(filePath)
|
||||||
|
return err == nil
|
||||||
|
}, testutil.IntervalFast)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "read exit code: 1\n", string(content))
|
||||||
|
}
|
||||||
|
|
||||||
func sshClient(t *testing.T, addr string) *ssh.Client {
|
func sshClient(t *testing.T, addr string) *ssh.Client {
|
||||||
conn, err := net.Dial("tcp", addr)
|
conn, err := net.Dial("tcp", addr)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
Reference in New Issue
Block a user