mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
fix(agent): More protection for lost output of SSH PTY commands (#6833)
Fixes #6656 (part 2)
This commit is contained in:
committed by
GitHub
parent
164528176a
commit
891bbda995
@ -1125,13 +1125,28 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
go func() {
|
||||
_, _ = io.Copy(ptty.Input(), session)
|
||||
}()
|
||||
|
||||
// In low parallelism scenarios, the command may exit and we may close
|
||||
// the pty before the output copy has started. This can result in the
|
||||
// output being lost. To avoid this, we wait for the output copy to
|
||||
// start before waiting for the command to exit. This ensures that the
|
||||
// output copy goroutine will be scheduled before calling close on the
|
||||
// pty. There is still a risk of data loss if a command produces a lot
|
||||
// of output, see TestAgent_Session_TTY_HugeOutputIsNotLost (skipped).
|
||||
outputCopyStarted := make(chan struct{})
|
||||
ptyOutput := func() io.Reader {
|
||||
defer close(outputCopyStarted)
|
||||
return ptty.Output()
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
// Ensure data is flushed to session on command exit, if we
|
||||
// close the session too soon, we might lose data.
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(session, ptty.Output())
|
||||
_, _ = io.Copy(session, ptyOutput())
|
||||
}()
|
||||
<-outputCopyStarted
|
||||
|
||||
err = process.Wait()
|
||||
var exitErr *exec.ExitError
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
|
@ -373,9 +373,12 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
||||
|
||||
var stdout bytes.Buffer
|
||||
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
||||
// assuming bug is present.
|
||||
// assuming bug is present. Limiting GOMAXPROCS further
|
||||
// increases the chance of failure.
|
||||
// Using 1000 iterations is basically a guaranteed failure (but let's
|
||||
// not increase test times needlessly).
|
||||
// Limit GOMAXPROCS (e.g. `export GOMAXPROCS=1`) to further increase
|
||||
// chance of failure. Also -race helps.
|
||||
for i := 0; i < 5; i++ {
|
||||
func() {
|
||||
stdout.Reset()
|
||||
@ -399,6 +402,63 @@ func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_Session_TTY_HugeOutputIsNotLost(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
// This might be our implementation, or ConPTY itself.
|
||||
// It's difficult to find extensive tests for it, so
|
||||
// it seems like it could be either.
|
||||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
t.Skip("This test proves we have a bug where parts of large output on a PTY can be lost after the command exits, skipped to avoid test failures.")
|
||||
|
||||
// This test is here to prevent prove we have a bug where quickly executing
|
||||
// commands (with TTY) don't flush their output to the SSH session. This is
|
||||
// due to the pty being closed before all the output has been copied, but
|
||||
// protecting against this requires a non-trivial rewrite of the output
|
||||
// processing (or figuring out a way to put the pty in a mode where this
|
||||
// does not happen).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
|
||||
sshClient, err := conn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
ptty := ptytest.New(t)
|
||||
|
||||
var stdout bytes.Buffer
|
||||
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
||||
// assuming bug is present.
|
||||
// Using 10 iterations is basically a guaranteed failure (but let's
|
||||
// not increase test times needlessly). Run with -race and do not
|
||||
// limit parallelism (`export GOMAXPROCS=10`) to increase the chance
|
||||
// of failure.
|
||||
for i := 0; i < 1; i++ {
|
||||
func() {
|
||||
stdout.Reset()
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||||
require.NoError(t, err)
|
||||
|
||||
session.Stdout = &stdout
|
||||
session.Stderr = ptty.Output()
|
||||
session.Stdin = ptty.Input()
|
||||
want := strings.Repeat("wazzup", 1024+1) // ~6KB, +1 because 1024 is a common buffer size.
|
||||
err = session.Start("echo " + want)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = session.Wait()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, stdout.String(), want, "should output entire greeting")
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:paralleltest // This test reserves a port.
|
||||
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
||||
random, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
|
Reference in New Issue
Block a user