feat: Add SSH agent forwarding support to coder agent (#1548)

* feat: Add SSH agent forwarding support to coder agent

* feat: Add forward agent flag to `coder ssh`

* refactor: Share setup between SSH tests, sync goroutines

* feat: Add test for `coder ssh --forward-agent`

* fix: Fix test flakes and implement Deans suggestion for helpers

* fix: Add example to config-ssh

* fix: Allow forwarding agent via -A

Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
Mathias Fredriksson
2022-05-25 21:28:10 +03:00
committed by GitHub
parent 22ef456164
commit 527f1f3bc3
4 changed files with 211 additions and 69 deletions

View File

@ -391,6 +391,16 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
return err return err
} }
if ssh.AgentRequested(session) {
l, err := ssh.NewAgentListener()
if err != nil {
return xerrors.Errorf("new agent listener: %w", err)
}
defer l.Close()
go ssh.ForwardAgentConnections(l, session)
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
}
sshPty, windowSize, isPty := session.Pty() sshPty, windowSize, isPty := session.Pty()
if isPty { if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

View File

@ -38,6 +38,11 @@ func configSSH() *cobra.Command {
Annotations: workspaceCommand, Annotations: workspaceCommand,
Use: "config-ssh", Use: "config-ssh",
Short: "Populate your SSH config with Host entries for all of your workspaces", Short: "Populate your SSH config with Host entries for all of your workspaces",
Example: `
- You can use -o (or --ssh-option) so set SSH options to be used for all your
workspaces.
` + cliui.Styles.Code.Render("$ coder config-ssh -o ForwardAgent=yes"),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient(cmd) client, err := createClient(cmd)
if err != nil { if err != nil {

View File

@ -15,6 +15,7 @@ import (
"github.com/mattn/go-isatty" "github.com/mattn/go-isatty"
"github.com/spf13/cobra" "github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent"
"golang.org/x/term" "golang.org/x/term"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@ -32,6 +33,7 @@ func ssh() *cobra.Command {
var ( var (
stdio bool stdio bool
shuffle bool shuffle bool
forwardAgent bool
wsPollInterval time.Duration wsPollInterval time.Duration
) )
cmd := &cobra.Command{ cmd := &cobra.Command{
@ -108,6 +110,17 @@ func ssh() *cobra.Command {
return err return err
} }
if forwardAgent && os.Getenv("SSH_AUTH_SOCK") != "" {
err = gosshagent.ForwardToRemote(sshClient, os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
return xerrors.Errorf("forward agent failed: %w", err)
}
err = gosshagent.RequestAgentForwarding(sshSession)
if err != nil {
return xerrors.Errorf("request agent forwarding failed: %w", err)
}
}
stdoutFile, valid := cmd.OutOrStdout().(*os.File) stdoutFile, valid := cmd.OutOrStdout().(*os.File)
if valid && isatty.IsTerminal(stdoutFile.Fd()) { if valid && isatty.IsTerminal(stdoutFile.Fd()) {
state, err := term.MakeRaw(int(os.Stdin.Fd())) state, err := term.MakeRaw(int(os.Stdin.Fd()))
@ -156,8 +169,9 @@ func ssh() *cobra.Command {
} }
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.") cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace") cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
_ = cmd.Flags().MarkHidden("shuffle") _ = cmd.Flags().MarkHidden("shuffle")
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
return cmd return cmd
} }

View File

@ -1,8 +1,14 @@
package cli_test package cli_test
import ( import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"io" "io"
"net" "net"
"path/filepath"
"runtime" "runtime"
"testing" "testing"
"time" "time"
@ -11,9 +17,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
gosshagent "golang.org/x/crypto/ssh/agent"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent" "github.com/coder/coder/agent"
"github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/coderdtest"
@ -23,49 +31,53 @@ import (
"github.com/coder/coder/pty/ptytest" "github.com/coder/coder/pty/ptytest"
) )
func setupWorkspaceForSSH(t *testing.T) (*codersdk.Client, codersdk.Workspace, string) {
t.Helper()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
user := coderdtest.CreateFirstUser(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
}},
}},
},
},
}},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
return client, workspace, agentToken
}
func TestSSH(t *testing.T) { func TestSSH(t *testing.T) {
t.Skip("This is causing test flakes. TODO @cian fix this")
t.Parallel() t.Parallel()
t.Run("ImmediateExit", func(t *testing.T) { t.Run("ImmediateExit", func(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) client, workspace, agentToken := setupWorkspaceForSSH(t)
user := coderdtest.CreateFirstUser(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
}},
}},
},
},
}},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
cmd, root := clitest.New(t, "ssh", workspace.Name) cmd, root := clitest.New(t, "ssh", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t) pty := ptytest.New(t)
cmd.SetIn(pty.Input()) cmd.SetIn(pty.Input())
cmd.SetErr(pty.Output()) cmd.SetErr(pty.Output())
cmd.SetOut(pty.Output()) cmd.SetOut(pty.Output())
go func() { cmdDone := tGo(t, func() {
defer close(doneChan)
err := cmd.Execute() err := cmd.Execute()
assert.NoError(t, err) assert.NoError(t, err)
}() })
pty.ExpectMatch("Waiting") pty.ExpectMatch("Waiting")
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL) agentClient := codersdk.New(client.URL)
@ -76,39 +88,16 @@ func TestSSH(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() {
_ = agentCloser.Close() _ = agentCloser.Close()
}) })
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command. // Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
pty.WriteLine("exit") pty.WriteLine("exit")
<-doneChan <-cmdDone
}) })
t.Run("Stdio", func(t *testing.T) { t.Run("Stdio", func(t *testing.T) {
t.Parallel() t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) client, workspace, agentToken := setupWorkspaceForSSH(t)
user := coderdtest.CreateFirstUser(t, client)
agentToken := uuid.NewString() _, _ = tGoContext(t, func(ctx context.Context) {
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
}},
}},
},
},
}},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
go func() {
// Run this async so the SSH command has to wait for // Run this async so the SSH command has to wait for
// the build and agent to connect! // the build and agent to connect!
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
@ -117,25 +106,22 @@ func TestSSH(t *testing.T) {
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{ agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
}) })
t.Cleanup(func() { <-ctx.Done()
_ = agentCloser.Close() _ = agentCloser.Close()
}) })
}()
clientOutput, clientInput := io.Pipe() clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe() serverOutput, serverInput := io.Pipe()
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name) cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root) clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
cmd.SetIn(clientOutput) cmd.SetIn(clientOutput)
cmd.SetOut(serverInput) cmd.SetOut(serverInput)
cmd.SetErr(io.Discard) cmd.SetErr(io.Discard)
go func() { cmdDone := tGo(t, func() {
defer close(doneChan)
err := cmd.Execute() err := cmd.Execute()
assert.NoError(t, err) assert.NoError(t, err)
}() })
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{ conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
Reader: serverOutput, Reader: serverOutput,
@ -157,8 +143,135 @@ func TestSSH(t *testing.T) {
err = sshClient.Close() err = sshClient.Close()
require.NoError(t, err) require.NoError(t, err)
_ = clientOutput.Close() _ = clientOutput.Close()
<-doneChan
<-cmdDone
}) })
//nolint:paralleltest // Disabled due to use of t.Setenv.
t.Run("ForwardAgent", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Test not supported on windows")
}
client, workspace, agentToken := setupWorkspaceForSSH(t)
_, _ = tGoContext(t, func(ctx context.Context) {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
<-ctx.Done()
_ = agentCloser.Close()
})
// Generate private key.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
kr := gosshagent.NewKeyring()
kr.Add(gosshagent.AddedKey{
PrivateKey: privateKey,
})
// Start up ssh agent listening on unix socket.
tmpdir := t.TempDir()
agentSock := filepath.Join(tmpdir, "agent.sock")
l, err := net.Listen("unix", agentSock)
require.NoError(t, err)
defer l.Close()
_ = tGo(t, func() {
for {
fd, err := l.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Logf("accept error: %v", err)
}
return
}
err = gosshagent.ServeAgent(kr, fd)
if !errors.Is(err, io.EOF) {
assert.NoError(t, err)
}
}
})
t.Setenv("SSH_AUTH_SOCK", agentSock)
cmd, root := clitest.New(t,
"ssh",
workspace.Name,
"--forward-agent",
)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
cmd.SetErr(io.Discard)
cmdDone := tGo(t, func() {
err := cmd.Execute()
assert.NoError(t, err)
})
// Ensure that SSH_AUTH_SOCK is set.
// Linux: /tmp/auth-agent3167016167/listener.sock
// macOS: /var/folders/ng/m1q0wft14hj0t3rtjxrdnzsr0000gn/T/auth-agent3245553419/listener.sock
pty.WriteLine("env")
pty.ExpectMatch("SSH_AUTH_SOCK=")
// Ensure that ssh-add lists our key.
pty.WriteLine("ssh-add -L")
keys, err := kr.List()
require.NoError(t, err)
pty.ExpectMatch(keys[0].String())
// And we're done.
pty.WriteLine("exit")
<-cmdDone
})
}
// tGoContext runs fn in a goroutine passing a context that will be
// canceled on test completion and wait until fn has finished executing.
// Done and cancel are returned for optionally waiting until completion
// or early cancellation.
//
// NOTE(mafredri): This could be moved to a helper library.
func tGoContext(t *testing.T, fn func(context.Context)) (done <-chan struct{}, cancel context.CancelFunc) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
doneC := make(chan struct{})
t.Cleanup(func() {
cancel()
<-done
})
go func() {
fn(ctx)
close(doneC)
}()
return doneC, cancel
}
// tGo runs fn in a goroutine and waits until fn has completed before
// test completion. Done is returned for optionally waiting for fn to
// exit.
//
// NOTE(mafredri): This could be moved to a helper library.
func tGo(t *testing.T, fn func()) (done <-chan struct{}) {
t.Helper()
doneC := make(chan struct{})
t.Cleanup(func() {
<-doneC
})
go func() {
fn()
close(doneC)
}()
return doneC
} }
type stdioConn struct { type stdioConn struct {