mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
fix: close ssh sessions gracefully (#10732)
Re-enables TestSSH/RemoteForward_Unix_Signal and addresses the underlying race: we were not closing the remote forward on context expiry, only the session and connection. However, there is still a more fundamental issue in that we don't have the ability to ensure that TCP sessions are properly terminated before tearing down the Tailnet conn. This is due to the assumption in the sockets API, that the underlying IP interface is long lived compared with the TCP socket, and thus closing a socket returns immediately and does not wait for the TCP termination handshake --- that is handled async in the tcpip stack. However, this assumption does not hold for us and tailnet, since on shutdown, we also tear down the tailnet connection, and this can race with the TCP termination. Closing the remote forward explicitly should prevent forward state from accumulating, since the Close() function waits for a reply from the remote SSH server. I've also attempted to workaround the TCP/tailnet issue for `--stdio` by using `CloseWrite()` instead of `Close()`. By closing the write side of the connection, half-close the TCP connection, and the server detects this and closes the other direction, which then triggers our read loop to exit only after the server has had a chance to process the close. TODO in a stacked PR is to implement this logic for `vscodessh` as well.
This commit is contained in:
161
cli/ssh.go
161
cli/ssh.go
@ -22,6 +22,7 @@ import (
|
|||||||
gosshagent "golang.org/x/crypto/ssh/agent"
|
gosshagent "golang.org/x/crypto/ssh/agent"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"cdr.dev/slog/sloggers/sloghuman"
|
"cdr.dev/slog/sloggers/sloghuman"
|
||||||
@ -129,6 +130,8 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
// log HTTP requests
|
// log HTTP requests
|
||||||
client.SetLogger(logger)
|
client.SetLogger(logger)
|
||||||
}
|
}
|
||||||
|
stack := newCloserStack(ctx, logger)
|
||||||
|
defer stack.close(nil)
|
||||||
|
|
||||||
if remoteForward != "" {
|
if remoteForward != "" {
|
||||||
isValid := validateRemoteForward(remoteForward)
|
isValid := validateRemoteForward(remoteForward)
|
||||||
@ -212,7 +215,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("dial agent: %w", err)
|
return xerrors.Errorf("dial agent: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
if err = stack.push("agent conn", conn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
conn.AwaitReachable(ctx)
|
conn.AwaitReachable(ctx)
|
||||||
|
|
||||||
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
|
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
|
||||||
@ -223,36 +228,20 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("connect SSH: %w", err)
|
return xerrors.Errorf("connect SSH: %w", err)
|
||||||
}
|
}
|
||||||
defer rawSSH.Close()
|
copier := &rawSSHCopier{conn: rawSSH, r: inv.Stdin, w: inv.Stdout}
|
||||||
|
if err = stack.push("rawSSHCopier", copier); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
watchAndClose(ctx, func() error {
|
watchAndClose(ctx, func() error {
|
||||||
return rawSSH.Close()
|
stack.close(xerrors.New("watchAndClose"))
|
||||||
|
return nil
|
||||||
}, logger, client, workspace)
|
}, logger, client, workspace)
|
||||||
}()
|
}()
|
||||||
|
copier.copy(&wg)
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
// Ensure stdout copy closes incase stdin is closed
|
|
||||||
// unexpectedly.
|
|
||||||
defer rawSSH.Close()
|
|
||||||
|
|
||||||
_, err := io.Copy(rawSSH, inv.Stdin)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "copy stdin error", slog.Error(err))
|
|
||||||
} else {
|
|
||||||
logger.Debug(ctx, "copy stdin complete")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
_, err = io.Copy(inv.Stdout, rawSSH)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(ctx, "copy stdout error", slog.Error(err))
|
|
||||||
} else {
|
|
||||||
logger.Debug(ctx, "copy stdout complete")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,13 +249,17 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("ssh client: %w", err)
|
return xerrors.Errorf("ssh client: %w", err)
|
||||||
}
|
}
|
||||||
defer sshClient.Close()
|
if err = stack.push("ssh client", sshClient); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
sshSession, err := sshClient.NewSession()
|
sshSession, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("ssh session: %w", err)
|
return xerrors.Errorf("ssh session: %w", err)
|
||||||
}
|
}
|
||||||
defer sshSession.Close()
|
if err = stack.push("sshSession", sshSession); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -274,10 +267,7 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
watchAndClose(
|
watchAndClose(
|
||||||
ctx,
|
ctx,
|
||||||
func() error {
|
func() error {
|
||||||
err := sshSession.Close()
|
stack.close(xerrors.New("watchAndClose"))
|
||||||
logger.Debug(ctx, "session close", slog.Error(err))
|
|
||||||
err = sshClient.Close()
|
|
||||||
logger.Debug(ctx, "client close", slog.Error(err))
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
logger,
|
logger,
|
||||||
@ -313,7 +303,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("forward GPG socket: %w", err)
|
return xerrors.Errorf("forward GPG socket: %w", err)
|
||||||
}
|
}
|
||||||
defer closer.Close()
|
if err = stack.push("forwardGPGAgent", closer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if remoteForward != "" {
|
if remoteForward != "" {
|
||||||
@ -326,7 +318,9 @@ func (r *RootCmd) ssh() *clibase.Cmd {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("ssh remote forward: %w", err)
|
return xerrors.Errorf("ssh remote forward: %w", err)
|
||||||
}
|
}
|
||||||
defer closer.Close()
|
if err = stack.push("sshRemoteForward", closer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stdoutFile, validOut := inv.Stdout.(*os.File)
|
stdoutFile, validOut := inv.Stdout.(*os.File)
|
||||||
@ -795,3 +789,106 @@ func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) {
|
|||||||
|
|
||||||
return string(bytes.TrimSpace(remoteSocket)), nil
|
return string(bytes.TrimSpace(remoteSocket)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type closerWithName struct {
|
||||||
|
name string
|
||||||
|
closer io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
type closerStack struct {
|
||||||
|
sync.Mutex
|
||||||
|
closers []closerWithName
|
||||||
|
closed bool
|
||||||
|
logger slog.Logger
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCloserStack(ctx context.Context, logger slog.Logger) *closerStack {
|
||||||
|
cs := &closerStack{logger: logger}
|
||||||
|
go cs.closeAfterContext(ctx)
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closerStack) closeAfterContext(ctx context.Context) {
|
||||||
|
<-ctx.Done()
|
||||||
|
c.close(ctx.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closerStack) close(err error) {
|
||||||
|
c.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
c.err = err
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
|
for i := len(c.closers) - 1; i >= 0; i-- {
|
||||||
|
cwn := c.closers[i]
|
||||||
|
cErr := cwn.closer.Close()
|
||||||
|
c.logger.Debug(context.Background(),
|
||||||
|
"closed item from stack", slog.F("name", cwn.name), slog.Error(cErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *closerStack) push(name string, closer io.Closer) error {
|
||||||
|
c.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.Unlock()
|
||||||
|
// since we're refusing to push it on the stack, close it now
|
||||||
|
err := closer.Close()
|
||||||
|
c.logger.Error(context.Background(),
|
||||||
|
"closed item rejected push", slog.F("name", name), slog.Error(err))
|
||||||
|
return xerrors.Errorf("already closed: %w", c.err)
|
||||||
|
}
|
||||||
|
c.closers = append(c.closers, closerWithName{name: name, closer: closer})
|
||||||
|
c.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawSSHCopier handles copying raw SSH data between the conn and the pair (r, w).
|
||||||
|
type rawSSHCopier struct {
|
||||||
|
conn *gonet.TCPConn
|
||||||
|
logger slog.Logger
|
||||||
|
r io.Reader
|
||||||
|
w io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
|
||||||
|
logCtx := context.Background()
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// We close connections using CloseWrite instead of Close, so that the SSH server sees the
|
||||||
|
// closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
|
||||||
|
// in the server-to-client direction to also be closed and the copy() routine will exit.
|
||||||
|
// This ensures that we don't leave any state in the server, like forwarded ports if
|
||||||
|
// copy() were to return and the underlying tailnet connection torn down before the TCP
|
||||||
|
// session exits. This is a bit of a hack to block shut down at the application layer, since
|
||||||
|
// we can't serialize the TCP and tailnet layers shutting down.
|
||||||
|
//
|
||||||
|
// Of course, if the underlying transport is broken, io.Copy will still return.
|
||||||
|
defer func() {
|
||||||
|
cwErr := c.conn.CloseWrite()
|
||||||
|
c.logger.Debug(logCtx, "closed raw SSH connection for writing", slog.Error(cwErr))
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.Copy(c.conn, c.r)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error(logCtx, "copy stdin error", slog.Error(err))
|
||||||
|
} else {
|
||||||
|
c.logger.Debug(logCtx, "copy stdin complete")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
_, err := io.Copy(c.w, c.conn)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error(logCtx, "copy stdout error", slog.Error(err))
|
||||||
|
} else {
|
||||||
|
c.logger.Debug(logCtx, "copy stdout complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *rawSSHCopier) Close() error {
|
||||||
|
return c.conn.CloseWrite()
|
||||||
|
}
|
||||||
|
@ -1,9 +1,16 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@ -56,3 +63,77 @@ func TestBuildWorkspaceLink(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
|
assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloserStack_Mainline(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
uut := newCloserStack(ctx, logger)
|
||||||
|
closes := new([]*fakeCloser)
|
||||||
|
fc0 := &fakeCloser{closes: closes}
|
||||||
|
fc1 := &fakeCloser{closes: closes}
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer uut.close(nil)
|
||||||
|
err := uut.push("fc0", fc0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = uut.push("fc1", fc1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
// order reversed
|
||||||
|
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloserStack_Context(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
uut := newCloserStack(ctx, logger)
|
||||||
|
closes := new([]*fakeCloser)
|
||||||
|
fc0 := &fakeCloser{closes: closes}
|
||||||
|
fc1 := &fakeCloser{closes: closes}
|
||||||
|
|
||||||
|
err := uut.push("fc0", fc0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = uut.push("fc1", fc1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
cancel()
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
uut.Lock()
|
||||||
|
defer uut.Unlock()
|
||||||
|
return uut.closed
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloserStack_PushAfterClose(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||||
|
uut := newCloserStack(ctx, logger)
|
||||||
|
closes := new([]*fakeCloser)
|
||||||
|
fc0 := &fakeCloser{closes: closes}
|
||||||
|
fc1 := &fakeCloser{closes: closes}
|
||||||
|
|
||||||
|
err := uut.push("fc0", fc0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
exErr := xerrors.New("test")
|
||||||
|
uut.close(exErr)
|
||||||
|
require.Equal(t, []*fakeCloser{fc0}, *closes)
|
||||||
|
|
||||||
|
err = uut.push("fc1", fc1)
|
||||||
|
require.ErrorIs(t, err, exErr)
|
||||||
|
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCloser struct {
|
||||||
|
closes *[]*fakeCloser
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeCloser) Close() error {
|
||||||
|
*c.closes = append(*c.closes, c)
|
||||||
|
return c.err
|
||||||
|
}
|
||||||
|
122
cli/ssh_test.go
122
cli/ssh_test.go
@ -249,10 +249,125 @@ func TestSSH(t *testing.T) {
|
|||||||
<-cmdDone
|
<-cmdDone
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||||
|
_, _ = tGoContext(t, func(ctx context.Context) {
|
||||||
|
// Run this async so the SSH command has to wait for
|
||||||
|
// the build and agent to connect!
|
||||||
|
_ = agenttest.New(t, client.URL, agentToken)
|
||||||
|
<-ctx.Done()
|
||||||
|
})
|
||||||
|
|
||||||
|
clientOutput, clientInput := io.Pipe()
|
||||||
|
serverOutput, serverInput := io.Pipe()
|
||||||
|
defer func() {
|
||||||
|
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
|
||||||
|
fsn := clitest.NewFakeSignalNotifier(t)
|
||||||
|
inv = inv.WithTestSignalNotifyContext(t, fsn.NotifyContext)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
inv.Stdin = clientOutput
|
||||||
|
inv.Stdout = serverInput
|
||||||
|
inv.Stderr = io.Discard
|
||||||
|
|
||||||
|
cmdDone := tGo(t, func() {
|
||||||
|
err := inv.WithContext(ctx).Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
|
||||||
|
Reader: serverOutput,
|
||||||
|
Writer: clientInput,
|
||||||
|
}, "", &ssh.ClientConfig{
|
||||||
|
// #nosec
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
sshClient := ssh.NewClient(conn, channels, requests)
|
||||||
|
|
||||||
|
tmpdir := tempDirUnixSocket(t)
|
||||||
|
|
||||||
|
remoteSock := path.Join(tmpdir, "remote.sock")
|
||||||
|
_, err = sshClient.ListenUnix(remoteSock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
fsn.Notify()
|
||||||
|
<-cmdDone
|
||||||
|
fsn.AssertStopped()
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
_, err = os.Stat(remoteSock)
|
||||||
|
return xerrors.Is(err, os.ErrNotExist)
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stdio_BrokenConn", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
client, workspace, agentToken := setupWorkspaceForAgent(t, nil)
|
||||||
|
_, _ = tGoContext(t, func(ctx context.Context) {
|
||||||
|
// Run this async so the SSH command has to wait for
|
||||||
|
// the build and agent to connect!
|
||||||
|
_ = agenttest.New(t, client.URL, agentToken)
|
||||||
|
<-ctx.Done()
|
||||||
|
})
|
||||||
|
|
||||||
|
clientOutput, clientInput := io.Pipe()
|
||||||
|
serverOutput, serverInput := io.Pipe()
|
||||||
|
defer func() {
|
||||||
|
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
|
||||||
|
clitest.SetupConfig(t, client, root)
|
||||||
|
inv.Stdin = clientOutput
|
||||||
|
inv.Stdout = serverInput
|
||||||
|
inv.Stderr = io.Discard
|
||||||
|
|
||||||
|
cmdDone := tGo(t, func() {
|
||||||
|
err := inv.WithContext(ctx).Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
|
||||||
|
Reader: serverOutput,
|
||||||
|
Writer: clientInput,
|
||||||
|
}, "", &ssh.ClientConfig{
|
||||||
|
// #nosec
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
sshClient := ssh.NewClient(conn, channels, requests)
|
||||||
|
_ = serverOutput.Close()
|
||||||
|
_ = clientInput.Close()
|
||||||
|
select {
|
||||||
|
case <-cmdDone:
|
||||||
|
// OK
|
||||||
|
case <-time.After(testutil.WaitShort):
|
||||||
|
t.Error("timeout waiting for command to exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = sshClient.Close()
|
||||||
|
})
|
||||||
|
|
||||||
// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
|
// Test that we handle OS signals properly while remote forwarding, and don't just leave the TCP
|
||||||
// socket hanging.
|
// socket hanging.
|
||||||
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) {
|
t.Run("RemoteForward_Unix_Signal", func(t *testing.T) {
|
||||||
t.Skip("still flaky")
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("No unix sockets on windows")
|
t.Skip("No unix sockets on windows")
|
||||||
}
|
}
|
||||||
@ -578,12 +693,13 @@ func TestSSH(t *testing.T) {
|
|||||||
l, err := net.Listen("unix", agentSock)
|
l, err := net.Listen("unix", agentSock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
remoteSock := filepath.Join(tmpdir, "remote.sock")
|
||||||
|
|
||||||
inv, root := clitest.New(t,
|
inv, root := clitest.New(t,
|
||||||
"ssh",
|
"ssh",
|
||||||
workspace.Name,
|
workspace.Name,
|
||||||
"--remote-forward",
|
"--remote-forward",
|
||||||
"/tmp/test.sock:"+agentSock,
|
fmt.Sprintf("%s:%s", remoteSock, agentSock),
|
||||||
)
|
)
|
||||||
clitest.SetupConfig(t, client, root)
|
clitest.SetupConfig(t, client, root)
|
||||||
pty := ptytest.New(t).Attach(inv)
|
pty := ptytest.New(t).Attach(inv)
|
||||||
@ -598,7 +714,7 @@ func TestSSH(t *testing.T) {
|
|||||||
_ = pty.Peek(ctx, 1)
|
_ = pty.Peek(ctx, 1)
|
||||||
|
|
||||||
// Download the test page
|
// Download the test page
|
||||||
pty.WriteLine("ss -xl state listening src /tmp/test.sock | wc -l")
|
pty.WriteLine(fmt.Sprintf("ss -xl state listening src %s | wc -l", remoteSock))
|
||||||
pty.ExpectMatch("2")
|
pty.ExpectMatch("2")
|
||||||
|
|
||||||
// And we're done.
|
// And we're done.
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"tailscale.com/ipn/ipnstate"
|
"tailscale.com/ipn/ipnstate"
|
||||||
"tailscale.com/net/speedtest"
|
"tailscale.com/net/speedtest"
|
||||||
|
|
||||||
@ -249,7 +250,7 @@ func (c *WorkspaceAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID,
|
|||||||
|
|
||||||
// SSH pipes the SSH protocol over the returned net.Conn.
|
// SSH pipes the SSH protocol over the returned net.Conn.
|
||||||
// This connects to the built-in SSH server in the workspace agent.
|
// This connects to the built-in SSH server in the workspace agent.
|
||||||
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (net.Conn, error) {
|
func (c *WorkspaceAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
|
||||||
ctx, span := tracing.StartSpan(ctx)
|
ctx, span := tracing.StartSpan(ctx)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user