fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds (#3354)

* fix: Improve shutdown procedure of ssh, portforward, wgtunnel cmds

We could turn it into a practice to wrap `cmd.Context()` so that we have
more fine-grained control of cancellation. Sometimes in tests we may be
running commands with a context that is never canceled.

Related to #3221

* fix: Set ssh session stderr to stderr
This commit is contained in:
Mathias Fredriksson
2022-08-02 17:44:59 +03:00
committed by GitHub
parent 5ae19f097e
commit 83c63d4a63
4 changed files with 75 additions and 53 deletions

View File

@ -55,6 +55,9 @@ func portForward() *cobra.Command {
}, },
), ),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards) specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
if err != nil { if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err) return xerrors.Errorf("parse port-forward specs: %w", err)
@ -72,7 +75,7 @@ func portForward() *cobra.Command {
return err return err
} }
workspace, agent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) workspace, agent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil { if err != nil {
return err return err
} }
@ -80,13 +83,13 @@ func portForward() *cobra.Command {
return xerrors.New("workspace must be in start transition to port-forward") return xerrors.New("workspace must be in start transition to port-forward")
} }
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil { if err != nil {
return err return err
} }
} }
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, agent.ID) return client.WorkspaceAgent(ctx, agent.ID)
@ -96,7 +99,7 @@ func portForward() *cobra.Command {
return xerrors.Errorf("await agent: %w", err) return xerrors.Errorf("await agent: %w", err)
} }
conn, err := client.DialWorkspaceAgent(cmd.Context(), agent.ID, nil) conn, err := client.DialWorkspaceAgent(ctx, agent.ID, nil)
if err != nil { if err != nil {
return xerrors.Errorf("dial workspace agent: %w", err) return xerrors.Errorf("dial workspace agent: %w", err)
} }
@ -104,7 +107,6 @@ func portForward() *cobra.Command {
// Start all listeners. // Start all listeners.
var ( var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup) wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs)) listeners = make([]net.Listener, len(specs))
closeAllListeners = func() { closeAllListeners = func() {
@ -116,11 +118,11 @@ func portForward() *cobra.Command {
} }
} }
) )
defer cancel() defer closeAllListeners()
for i, spec := range specs { for i, spec := range specs {
l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) l, err := listenAndPortForward(ctx, cmd, conn, wg, spec)
if err != nil { if err != nil {
closeAllListeners()
return err return err
} }
listeners[i] = l listeners[i] = l
@ -129,7 +131,10 @@ func portForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close // Wait for the context to be canceled or for a signal and close
// all listeners. // all listeners.
var closeErr error var closeErr error
wg.Add(1)
go func() { go func() {
defer wg.Done()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

View File

@ -51,6 +51,9 @@ func ssh() *cobra.Command {
Short: "SSH into a workspace", Short: "SSH into a workspace",
Args: cobra.ArbitraryArgs, Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
client, err := createClient(cmd) client, err := createClient(cmd)
if err != nil { if err != nil {
return err return err
@ -68,14 +71,14 @@ func ssh() *cobra.Command {
} }
} }
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], shuffle) workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle)
if err != nil { if err != nil {
return err return err
} }
// OpenSSH passes stderr directly to the calling TTY. // OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed. // This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID) return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@ -85,19 +88,16 @@ func ssh() *cobra.Command {
return xerrors.Errorf("await agent: %w", err) return xerrors.Errorf("await agent: %w", err)
} }
var ( var newSSHClient func() (*gossh.Client, error)
sshClient *gossh.Client
sshSession *gossh.Session
)
if !wireguard { if !wireguard {
conn, err := client.DialWorkspaceAgent(cmd.Context(), workspaceAgent.ID, nil) conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() defer conn.Close()
stopPolling := tryPollWorkspaceAutostop(cmd.Context(), client, workspace) stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
defer stopPolling() defer stopPolling()
if stdio { if stdio {
@ -105,6 +105,8 @@ func ssh() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
defer rawSSH.Close()
go func() { go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH) _, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}() }()
@ -112,15 +114,7 @@ func ssh() *cobra.Command {
return nil return nil
} }
sshClient, err = conn.SSHClient() newSSHClient = conn.SSHClient
if err != nil {
return err
}
sshSession, err = sshClient.NewSession()
if err != nil {
return err
}
} else { } else {
// TODO: more granual control of Tailscale logging. // TODO: more granual control of Tailscale logging.
peerwg.Logf = tslogger.Discard peerwg.Logf = tslogger.Discard
@ -133,8 +127,9 @@ func ssh() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("create wireguard network: %w", err) return xerrors.Errorf("create wireguard network: %w", err)
} }
defer wgn.Close()
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID, Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(), NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey, DiscoPublicKey: wgn.DiscoPublicKey,
@ -155,10 +150,11 @@ func ssh() *cobra.Command {
} }
if stdio { if stdio {
rawSSH, err := wgn.SSH(cmd.Context(), workspaceAgent.IPv6.IP()) rawSSH, err := wgn.SSH(ctx, workspaceAgent.IPv6.IP())
if err != nil { if err != nil {
return err return err
} }
defer rawSSH.Close()
go func() { go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH) _, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
@ -167,17 +163,30 @@ func ssh() *cobra.Command {
return nil return nil
} }
sshClient, err = wgn.SSHClient(cmd.Context(), workspaceAgent.IPv6.IP()) newSSHClient = func() (*gossh.Client, error) {
if err != nil { return wgn.SSHClient(ctx, workspaceAgent.IPv6.IP())
return err
}
sshSession, err = sshClient.NewSession()
if err != nil {
return err
} }
} }
sshClient, err := newSSHClient()
if err != nil {
return err
}
defer sshClient.Close()
sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
defer sshSession.Close()
// Ensure context cancellation is propagated to the
// SSH session, e.g. to cancel `Wait()` at the end.
go func() {
<-ctx.Done()
_ = sshSession.Close()
}()
if identityAgent == "" { if identityAgent == "" {
identityAgent = os.Getenv("SSH_AUTH_SOCK") identityAgent = os.Getenv("SSH_AUTH_SOCK")
} }
@ -203,15 +212,18 @@ func ssh() *cobra.Command {
_ = term.Restore(int(stdinFile.Fd()), state) _ = term.Restore(int(stdinFile.Fd()), state)
}() }()
windowChange := listenWindowSize(cmd.Context()) windowChange := listenWindowSize(ctx)
go func() { go func() {
for { for {
select { select {
case <-cmd.Context().Done(): case <-ctx.Done():
return return
case <-windowChange: case <-windowChange:
} }
width, height, _ := term.GetSize(int(stdoutFile.Fd())) width, height, err := term.GetSize(int(stdoutFile.Fd()))
if err != nil {
continue
}
_ = sshSession.WindowChange(height, width) _ = sshSession.WindowChange(height, width)
} }
}() }()
@ -224,13 +236,17 @@ func ssh() *cobra.Command {
sshSession.Stdin = cmd.InOrStdin() sshSession.Stdin = cmd.InOrStdin()
sshSession.Stdout = cmd.OutOrStdout() sshSession.Stdout = cmd.OutOrStdout()
sshSession.Stderr = cmd.OutOrStdout() sshSession.Stderr = cmd.ErrOrStderr()
err = sshSession.Shell() err = sshSession.Shell()
if err != nil { if err != nil {
return err return err
} }
// Put cancel at the top of the defer stack to initiate
// shutdown of services.
defer cancel()
err = sshSession.Wait() err = sshSession.Wait()
if err != nil { if err != nil {
// If the connection drops unexpectedly, we get an ExitMissingError but no other // If the connection drops unexpectedly, we get an ExitMissingError but no other
@ -259,16 +275,14 @@ func ssh() *cobra.Command {
// getWorkspaceAgent returns the workspace and agent selected using either the // getWorkspaceAgent returns the workspace and agent selected using either the
// `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent // `<workspace>[.<agent>]` syntax via `in` or picks a random workspace and agent
// if `shuffle` is true. // if `shuffle` is true.
func getWorkspaceAndAgent(cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive
ctx := cmd.Context()
var ( var (
workspace codersdk.Workspace workspace codersdk.Workspace
workspaceParts = strings.Split(in, ".") workspaceParts = strings.Split(in, ".")
err error err error
) )
if shuffle { if shuffle {
workspaces, err := client.Workspaces(cmd.Context(), codersdk.WorkspaceFilter{ workspaces, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{
Owner: codersdk.Me, Owner: codersdk.Me,
}) })
if err != nil { if err != nil {

View File

@ -229,7 +229,7 @@ func TestSSH(t *testing.T) {
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output()) cmd.SetOut(pty.Output())
cmd.SetErr(io.Discard) cmd.SetErr(pty.Output())
cmdDone := tGo(t, func() { cmdDone := tGo(t, func() {
err := cmd.ExecuteContext(ctx) err := cmd.ExecuteContext(ctx)
assert.NoError(t, err) assert.NoError(t, err)
@ -248,9 +248,6 @@ func TestSSH(t *testing.T) {
// And we're done. // And we're done.
pty.WriteLine("exit") pty.WriteLine("exit")
// Read output to prevent hang on macOS, see:
// https://github.com/coder/coder/issues/2122
pty.ExpectMatch("exit")
<-cmdDone <-cmdDone
}) })
} }

View File

@ -52,6 +52,9 @@ func wireguardPortForward() *cobra.Command {
}, },
), ),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
specs, err := parsePortForwards(tcpForwards, nil, nil) specs, err := parsePortForwards(tcpForwards, nil, nil)
if err != nil { if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err) return xerrors.Errorf("parse port-forward specs: %w", err)
@ -69,7 +72,7 @@ func wireguardPortForward() *cobra.Command {
return err return err
} }
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false) workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false)
if err != nil { if err != nil {
return err return err
} }
@ -77,13 +80,13 @@ func wireguardPortForward() *cobra.Command {
return xerrors.New("workspace must be in start transition to port-forward") return xerrors.New("workspace must be in start transition to port-forward")
} }
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt) err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil { if err != nil {
return err return err
} }
} }
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{ err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID) return client.WorkspaceAgent(ctx, workspaceAgent.ID)
@ -101,8 +104,9 @@ func wireguardPortForward() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("create wireguard network: %w", err) return xerrors.Errorf("create wireguard network: %w", err)
} }
defer wgn.Close()
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{ err = client.PostWireguardPeer(ctx, workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID, Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(), NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey, DiscoPublicKey: wgn.DiscoPublicKey,
@ -124,7 +128,6 @@ func wireguardPortForward() *cobra.Command {
// Start all listeners. // Start all listeners.
var ( var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup) wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs)) listeners = make([]net.Listener, len(specs))
closeAllListeners = func() { closeAllListeners = func() {
@ -136,11 +139,11 @@ func wireguardPortForward() *cobra.Command {
} }
} }
) )
defer cancel() defer closeAllListeners()
for i, spec := range specs { for i, spec := range specs {
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP()) l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
if err != nil { if err != nil {
closeAllListeners()
return err return err
} }
listeners[i] = l listeners[i] = l
@ -149,7 +152,10 @@ func wireguardPortForward() *cobra.Command {
// Wait for the context to be canceled or for a signal and close // Wait for the context to be canceled or for a signal and close
// all listeners. // all listeners.
var closeErr error var closeErr error
wg.Add(1)
go func() { go func() {
defer wg.Done()
sigs := make(chan os.Signal, 1) sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)