mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
fix(agent): Prevent SSH TTYs from losing command output on exit (#6777)
This commit is contained in:
committed by
GitHub
parent
d7d210de36
commit
76bdde7f1b
@ -844,6 +844,7 @@ func (a *agent) init(ctx context.Context) {
|
|||||||
_ = session.Exit(MagicSessionErrorCode)
|
_ = session.Exit(MagicSessionErrorCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
_ = session.Exit(0)
|
||||||
},
|
},
|
||||||
HostSigners: []ssh.Signer{randomSigner},
|
HostSigners: []ssh.Signer{randomSigner},
|
||||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||||
@ -1100,7 +1101,9 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("start command: %w", err)
|
return xerrors.Errorf("start command: %w", err)
|
||||||
}
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
defer func() {
|
defer func() {
|
||||||
|
defer wg.Wait()
|
||||||
closeErr := ptty.Close()
|
closeErr := ptty.Close()
|
||||||
if closeErr != nil {
|
if closeErr != nil {
|
||||||
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||||
@ -1117,10 +1120,16 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
// We don't add input copy to wait group because
|
||||||
|
// it won't return until the session is closed.
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = io.Copy(ptty.Input(), session)
|
_, _ = io.Copy(ptty.Input(), session)
|
||||||
}()
|
}()
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
// Ensure data is flushed to session on command exit, if we
|
||||||
|
// close the session too soon, we might lose data.
|
||||||
|
defer wg.Done()
|
||||||
_, _ = io.Copy(session, ptty.Output())
|
_, _ = io.Copy(session, ptty.Output())
|
||||||
}()
|
}()
|
||||||
err = process.Wait()
|
err = process.Wait()
|
||||||
|
@ -348,6 +348,57 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
|
|||||||
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
|
require.NotContains(t, stdout.String(), wantNotMOTD, "should not show motd")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAgent_Session_TTY_FastCommandHasOutput(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
// This might be our implementation, or ConPTY itself.
|
||||||
|
// It's difficult to find extensive tests for it, so
|
||||||
|
// it seems like it could be either.
|
||||||
|
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test is here to prevent regressions where quickly executing
|
||||||
|
// commands (with TTY) don't flush their output to the SSH session.
|
||||||
|
//
|
||||||
|
// See: https://github.com/coder/coder/issues/6656
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
//nolint:dogsled
|
||||||
|
conn, _, _, _, _ := setupAgent(t, agentsdk.Metadata{}, 0)
|
||||||
|
sshClient, err := conn.SSHClient(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sshClient.Close()
|
||||||
|
|
||||||
|
ptty := ptytest.New(t)
|
||||||
|
|
||||||
|
var stdout bytes.Buffer
|
||||||
|
// NOTE(mafredri): Increase iterations to increase chance of failure,
|
||||||
|
// assuming bug is present.
|
||||||
|
// Using 1000 iterations is basically a guaranteed failure (but let's
|
||||||
|
// not increase test times needlessly).
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
func() {
|
||||||
|
stdout.Reset()
|
||||||
|
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
err = session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
session.Stdout = &stdout
|
||||||
|
session.Stderr = ptty.Output()
|
||||||
|
session.Stdin = ptty.Input()
|
||||||
|
err = session.Start("echo wazzup")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = session.Wait()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, stdout.String(), "wazzup", "should output greeting")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//nolint:paralleltest // This test reserves a port.
|
//nolint:paralleltest // This test reserves a port.
|
||||||
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
func TestAgent_TCPLocalForwarding(t *testing.T) {
|
||||||
random, err := net.Listen("tcp", "127.0.0.1:0")
|
random, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
Reference in New Issue
Block a user