mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
fix(agent): ensure SSH server shutdown with process groups (#17227)
Fix hanging workspace shutdowns caused by orphaned SSH child processes. Key changes: - Create process groups for non-PTY SSH sessions - Send SIGHUP to entire process group for proper termination - Add 5-second timeout to prevent indefinite blocking Fixes #17108
This commit is contained in:
committed by
GitHub
parent
b60934b180
commit
b61f0ab958
@ -1773,15 +1773,22 @@ func (a *agent) Close() error {
|
|||||||
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)
|
a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)
|
||||||
|
|
||||||
// Attempt to gracefully shut down all active SSH connections and
|
// Attempt to gracefully shut down all active SSH connections and
|
||||||
// stop accepting new ones.
|
// stop accepting new ones. If all processes have not exited after 5
|
||||||
err := a.sshServer.Shutdown(a.hardCtx)
|
// seconds, we just log it and move on as it's more important to run
|
||||||
|
// the shutdown scripts. A typical shutdown time for containers is
|
||||||
|
// 10 seconds, so this still leaves a bit of time to run the
|
||||||
|
// shutdown scripts in the worst-case.
|
||||||
|
sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 5*time.Second)
|
||||||
|
defer sshShutdownCancel()
|
||||||
|
err := a.sshServer.Shutdown(sshShutdownCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
}
|
a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err))
|
||||||
err = a.sshServer.Close()
|
} else {
|
||||||
if err != nil {
|
a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err))
|
||||||
a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for SSH to shut down before the general graceful cancel, because
|
// wait for SSH to shut down before the general graceful cancel, because
|
||||||
// this triggers a disconnect in the tailnet layer, telling all clients to
|
// this triggers a disconnect in the tailnet layer, telling all clients to
|
||||||
// shut down their wireguard tunnels to us. If SSH sessions are still up,
|
// shut down their wireguard tunnels to us. If SSH sessions are still up,
|
||||||
|
@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
|
|||||||
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
|
||||||
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
|
s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)
|
||||||
|
|
||||||
|
// Create a process group and send SIGHUP to child processes,
|
||||||
|
// otherwise context cancellation will not propagate properly
|
||||||
|
// and SSH server close may be delayed.
|
||||||
|
cmd.SysProcAttr = cmdSysProcAttr()
|
||||||
|
cmd.Cancel = cmdCancel(session.Context(), logger, cmd)
|
||||||
|
|
||||||
cmd.Stdout = session
|
cmd.Stdout = session
|
||||||
cmd.Stderr = session.Stderr()
|
cmd.Stderr = session.Stderr()
|
||||||
// This blocks forever until stdin is received if we don't
|
// This blocks forever until stdin is received if we don't
|
||||||
@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
|
|||||||
// Serve starts the server to handle incoming connections on the provided listener.
|
// Serve starts the server to handle incoming connections on the provided listener.
|
||||||
// It returns an error if no host keys are set or if there is an issue accepting connections.
|
// It returns an error if no host keys are set or if there is an issue accepting connections.
|
||||||
func (s *Server) Serve(l net.Listener) (retErr error) {
|
func (s *Server) Serve(l net.Listener) (retErr error) {
|
||||||
if len(s.srv.HostSigners) == 0 {
|
// Ensure we're not mutating HostSigners as we're reading it.
|
||||||
|
s.mu.RLock()
|
||||||
|
noHostKeys := len(s.srv.HostSigners) == 0
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if noHostKeys {
|
||||||
return xerrors.New("no host keys set")
|
return xerrors.New("no host keys set")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1054,27 +1065,36 @@ func (s *Server) Close() error {
|
|||||||
}
|
}
|
||||||
s.closing = make(chan struct{})
|
s.closing = make(chan struct{})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.logger.Debug(ctx, "closing server")
|
||||||
|
|
||||||
|
// Stop accepting new connections.
|
||||||
|
s.logger.Debug(ctx, "closing all active listeners", slog.F("count", len(s.listeners)))
|
||||||
|
for l := range s.listeners {
|
||||||
|
_ = l.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// Close all active sessions to gracefully
|
// Close all active sessions to gracefully
|
||||||
// terminate client connections.
|
// terminate client connections.
|
||||||
|
s.logger.Debug(ctx, "closing all active sessions", slog.F("count", len(s.sessions)))
|
||||||
for ss := range s.sessions {
|
for ss := range s.sessions {
|
||||||
// We call Close on the underlying channel here because we don't
|
// We call Close on the underlying channel here because we don't
|
||||||
// want to send an exit status to the client (via Exit()).
|
// want to send an exit status to the client (via Exit()).
|
||||||
// Typically OpenSSH clients will return 255 as the exit status.
|
// Typically OpenSSH clients will return 255 as the exit status.
|
||||||
_ = ss.Close()
|
_ = ss.Close()
|
||||||
}
|
}
|
||||||
|
s.logger.Debug(ctx, "closing all active connections", slog.F("count", len(s.conns)))
|
||||||
// Close all active listeners and connections.
|
|
||||||
for l := range s.listeners {
|
|
||||||
_ = l.Close()
|
|
||||||
}
|
|
||||||
for c := range s.conns {
|
for c := range s.conns {
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the underlying SSH server.
|
s.logger.Debug(ctx, "closing SSH server")
|
||||||
err := s.srv.Close()
|
err := s.srv.Close()
|
||||||
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
s.logger.Debug(ctx, "waiting for all goroutines to exit")
|
||||||
s.wg.Wait() // Wait for all goroutines to exit.
|
s.wg.Wait() // Wait for all goroutines to exit.
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@ -1082,15 +1102,35 @@ func (s *Server) Close() error {
|
|||||||
s.closing = nil
|
s.closing = nil
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
s.logger.Debug(ctx, "closing server done")
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown gracefully closes all active SSH connections and stops
|
// Shutdown stops accepting new connections. The current implementation
|
||||||
// accepting new connections.
|
// calls Close() for simplicity instead of waiting for existing
|
||||||
//
|
// connections to close. If the context times out, Shutdown will return
|
||||||
// Shutdown is not implemented.
|
// but Close() may not have completed.
|
||||||
func (*Server) Shutdown(_ context.Context) error {
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
|
ch := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
|
||||||
|
// For now we just close the server.
|
||||||
|
ch <- s.Close()
|
||||||
|
}()
|
||||||
|
var err error
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case err = <-ch:
|
||||||
|
}
|
||||||
|
// Re-check for context cancellation precedence.
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = ctx.Err()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("close server: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/agent/agentexec"
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
|
|||||||
func TestNewServer_CloseActiveConnections(t *testing.T) {
|
func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.Background()
|
prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) {
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
t.Helper()
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
require.NoError(t, err)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
defer s.Close()
|
require.NoError(t, err)
|
||||||
err = s.UpdateHostSigner(42)
|
defer s.Close()
|
||||||
assert.NoError(t, err)
|
err = s.UpdateHostSigner(42)
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
err := s.Serve(ln)
|
|
||||||
assert.Error(t, err) // Server is closed.
|
|
||||||
}()
|
|
||||||
|
|
||||||
pty := ptytest.New(t)
|
|
||||||
|
|
||||||
doClose := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
c := sshClient(t, ln.Addr().String())
|
|
||||||
sess, err := c.NewSession()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
sess.Stdin = pty.Input()
|
|
||||||
sess.Stdout = pty.Output()
|
|
||||||
sess.Stderr = pty.Output()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = sess.Start("")
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
close(doClose)
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
err = sess.Wait()
|
require.NoError(t, err)
|
||||||
assert.Error(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-doClose
|
waitConns := make([]chan struct{}, 4)
|
||||||
err = s.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
wg.Wait()
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1 + len(waitConns))
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := s.Serve(ln)
|
||||||
|
assert.Error(t, err) // Server is closed.
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < len(waitConns); i++ {
|
||||||
|
waitConns[i] = make(chan struct{})
|
||||||
|
go func(ch chan struct{}) {
|
||||||
|
defer wg.Done()
|
||||||
|
c := sshClient(t, ln.Addr().String())
|
||||||
|
sess, err := c.NewSession()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
pty := ptytest.New(t)
|
||||||
|
sess.Stdin = pty.Input()
|
||||||
|
sess.Stdout = pty.Output()
|
||||||
|
sess.Stderr = pty.Output()
|
||||||
|
|
||||||
|
// Every other session will request a PTY.
|
||||||
|
if i%2 == 0 {
|
||||||
|
err = sess.RequestPty("xterm", 80, 80, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
// The 60 seconds here is intended to be longer than the
|
||||||
|
// test. The shutdown should propagate.
|
||||||
|
err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
close(ch)
|
||||||
|
err = sess.Wait()
|
||||||
|
assert.Error(t, err)
|
||||||
|
}(waitConns[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ch := range waitConns {
|
||||||
|
<-ch
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, wg.Wait
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Close", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
s, wait := prepare(ctx, t)
|
||||||
|
err := s.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Shutdown", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
s, wait := prepare(ctx, t)
|
||||||
|
err := s.Shutdown(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Shutdown Early", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
s, wait := prepare(ctx, t)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel()
|
||||||
|
err := s.Shutdown(ctx)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
wait()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewServer_Signal(t *testing.T) {
|
func TestNewServer_Signal(t *testing.T) {
|
||||||
|
24
agent/agentssh/exec_other.go
Normal file
24
agent/agentssh/exec_other.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package agentssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||||
|
return &syscall.SysProcAttr{
|
||||||
|
Setsid: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
|
||||||
|
return func() error {
|
||||||
|
logger.Debug(ctx, "cmdCancel: sending SIGHUP to process and children", slog.F("pid", cmd.Process.Pid))
|
||||||
|
return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
|
||||||
|
}
|
||||||
|
}
|
21
agent/agentssh/exec_windows.go
Normal file
21
agent/agentssh/exec_windows.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package agentssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cmdSysProcAttr() *syscall.SysProcAttr {
|
||||||
|
return &syscall.SysProcAttr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCancel(ctx context.Context, logger slog.Logger, cmd *exec.Cmd) func() error {
|
||||||
|
return func() error {
|
||||||
|
logger.Debug(ctx, "cmdCancel: sending interrupt to process", slog.F("pid", cmd.Process.Pid))
|
||||||
|
return cmd.Process.Signal(os.Interrupt)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user