feat: Add config-ssh command (#735)

* feat: Add config-ssh command

Closes #254 and #499.

* Fix Windows support
This commit is contained in:
Kyle Carberry
2022-03-30 17:59:54 -05:00
committed by GitHub
parent 6ab1a681c4
commit 6612e3c9c7
29 changed files with 554 additions and 115 deletions

View File

@ -1,5 +1,6 @@
{ {
"cSpell.words": [ "cSpell.words": [
"cliflag",
"cliui", "cliui",
"coderd", "coderd",
"coderdtest", "coderdtest",

View File

@ -56,24 +56,24 @@ type agent struct {
sshServer *ssh.Server sshServer *ssh.Server
} }
func (s *agent) run(ctx context.Context) { func (a *agent) run(ctx context.Context) {
var peerListener *peerbroker.Listener var peerListener *peerbroker.Listener
var err error var err error
// An exponential back-off occurs when the connection is failing to dial. // An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage. // This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = s.clientDialer(ctx, s.options) peerListener, err = a.clientDialer(ctx, a.options)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return return
} }
if s.isClosed() { if a.isClosed() {
return return
} }
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) a.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue continue
} }
s.options.Logger.Info(context.Background(), "connected") a.options.Logger.Info(context.Background(), "connected")
break break
} }
select { select {
@ -85,40 +85,40 @@ func (s *agent) run(ctx context.Context) {
for { for {
conn, err := peerListener.Accept() conn, err := peerListener.Accept()
if err != nil { if err != nil {
if s.isClosed() { if a.isClosed() {
return return
} }
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) a.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
s.run(ctx) a.run(ctx)
return return
} }
s.closeMutex.Lock() a.closeMutex.Lock()
s.connCloseWait.Add(1) a.connCloseWait.Add(1)
s.closeMutex.Unlock() a.closeMutex.Unlock()
go s.handlePeerConn(ctx, conn) go a.handlePeerConn(ctx, conn)
} }
} }
func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) { func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() { go func() {
<-conn.Closed() <-conn.Closed()
s.connCloseWait.Done() a.connCloseWait.Done()
}() }()
for { for {
channel, err := conn.Accept(ctx) channel, err := conn.Accept(ctx)
if err != nil { if err != nil {
if errors.Is(err, peer.ErrClosed) || s.isClosed() { if errors.Is(err, peer.ErrClosed) || a.isClosed() {
return return
} }
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) a.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
return return
} }
switch channel.Protocol() { switch channel.Protocol() {
case "ssh": case "ssh":
s.sshServer.HandleConn(channel.NetConn()) a.sshServer.HandleConn(channel.NetConn())
default: default:
s.options.Logger.Warn(ctx, "unhandled protocol from channel", a.options.Logger.Warn(ctx, "unhandled protocol from channel",
slog.F("protocol", channel.Protocol()), slog.F("protocol", channel.Protocol()),
slog.F("label", channel.Label()), slog.F("label", channel.Label()),
) )
@ -126,7 +126,7 @@ func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
} }
} }
func (s *agent) init(ctx context.Context) { func (a *agent) init(ctx context.Context) {
// Clients' should ignore the host key when connecting. // Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH, // The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security. // so SSH authentication doesn't improve security.
@ -138,17 +138,17 @@ func (s *agent) init(ctx context.Context) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
sshLogger := s.options.Logger.Named("ssh-server") sshLogger := a.options.Logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{} forwardHandler := &ssh.ForwardedTCPHandler{}
s.sshServer = &ssh.Server{ a.sshServer = &ssh.Server{
ChannelHandlers: ssh.DefaultChannelHandlers, ChannelHandlers: ssh.DefaultChannelHandlers,
ConnectionFailedCallback: func(conn net.Conn, err error) { ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
}, },
Handler: func(session ssh.Session) { Handler: func(session ssh.Session) {
err := s.handleSSHSession(session) err := a.handleSSHSession(session)
if err != nil { if err != nil {
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err)) a.options.Logger.Warn(ctx, "ssh session failed", slog.Error(err))
_ = session.Exit(1) _ = session.Exit(1)
return return
} }
@ -177,35 +177,26 @@ func (s *agent) init(ctx context.Context) {
}, },
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{ return &gossh.ServerConfig{
Config: gossh.Config{
// "arcfour" is the fastest SSH cipher. We prioritize throughput
// over encryption here, because the WebRTC connection is already
// encrypted. If possible, we'd disable encryption entirely here.
Ciphers: []string{"arcfour"},
},
NoClientAuth: true, NoClientAuth: true,
} }
}, },
} }
go s.run(ctx) go a.run(ctx)
} }
func (*agent) handleSSHSession(session ssh.Session) error { func (a *agent) handleSSHSession(session ssh.Session) error {
var ( var (
command string command string
args = []string{} args = []string{}
err error err error
) )
username := session.User()
if username == "" {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
return xerrors.Errorf("get current user: %w", err) return xerrors.Errorf("get current user: %w", err)
} }
username = currentUser.Username username := currentUser.Username
}
// gliderlabs/ssh returns a command slice of zero // gliderlabs/ssh returns a command slice of zero
// when a shell is requested. // when a shell is requested.
@ -249,9 +240,9 @@ func (*agent) handleSSHSession(session ssh.Session) error {
} }
go func() { go func() {
for win := range windowSize { for win := range windowSize {
err := ptty.Resize(uint16(win.Width), uint16(win.Height)) err = ptty.Resize(uint16(win.Width), uint16(win.Height))
if err != nil { if err != nil {
panic(err) a.options.Logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
} }
} }
}() }()
@ -286,24 +277,24 @@ func (*agent) handleSSHSession(session ssh.Session) error {
} }
// isClosed returns whether the API is closed or not. // isClosed returns whether the API is closed or not.
func (s *agent) isClosed() bool { func (a *agent) isClosed() bool {
select { select {
case <-s.closed: case <-a.closed:
return true return true
default: default:
return false return false
} }
} }
func (s *agent) Close() error { func (a *agent) Close() error {
s.closeMutex.Lock() a.closeMutex.Lock()
defer s.closeMutex.Unlock() defer a.closeMutex.Unlock()
if s.isClosed() { if a.isClosed() {
return nil return nil
} }
close(s.closed) close(a.closed)
s.closeCancel() a.closeCancel()
_ = s.sshServer.Close() _ = a.sshServer.Close()
s.connCloseWait.Wait() a.connCloseWait.Wait()
return nil return nil
} }

View File

@ -39,9 +39,6 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
return nil, xerrors.Errorf("ssh: %w", err) return nil, xerrors.Errorf("ssh: %w", err)
} }
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{ sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
Config: ssh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer // SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace. // connection already signifies user-intent to dial a workspace.
// #nosec // #nosec

View File

@ -27,5 +27,5 @@ func Get(username string) (string, error) {
} }
return parts[6], nil return parts[6], nil
} }
return "", xerrors.New("user not found in /etc/passwd and $SHELL not set") return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
} }

View File

@ -3,11 +3,11 @@ package cliui
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"sync" "sync"
"time" "time"
"github.com/briandowns/spinner" "github.com/briandowns/spinner"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
@ -21,7 +21,7 @@ type AgentOptions struct {
} }
// Agent displays a spinning indicator that waits for a workspace agent to connect. // Agent displays a spinning indicator that waits for a workspace agent to connect.
func Agent(cmd *cobra.Command, opts AgentOptions) error { func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
if opts.FetchInterval == 0 { if opts.FetchInterval == 0 {
opts.FetchInterval = 500 * time.Millisecond opts.FetchInterval = 500 * time.Millisecond
} }
@ -29,7 +29,7 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
opts.WarnInterval = 30 * time.Second opts.WarnInterval = 30 * time.Second
} }
var resourceMutex sync.Mutex var resourceMutex sync.Mutex
resource, err := opts.Fetch(cmd.Context()) resource, err := opts.Fetch(ctx)
if err != nil { if err != nil {
return xerrors.Errorf("fetch: %w", err) return xerrors.Errorf("fetch: %w", err)
} }
@ -40,7 +40,8 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
opts.WarnInterval = 0 opts.WarnInterval = 0
} }
spin := spinner.New(spinner.CharSets[78], 100*time.Millisecond, spinner.WithColor("fgHiGreen")) spin := spinner.New(spinner.CharSets[78], 100*time.Millisecond, spinner.WithColor("fgHiGreen"))
spin.Writer = cmd.OutOrStdout() spin.Writer = writer
spin.ForceOutput = true
spin.Suffix = " Waiting for connection from " + Styles.Field.Render(resource.Type+"."+resource.Name) + "..." spin.Suffix = " Waiting for connection from " + Styles.Field.Render(resource.Type+"."+resource.Name) + "..."
spin.Start() spin.Start()
defer spin.Stop() defer spin.Stop()
@ -51,7 +52,7 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
defer timer.Stop() defer timer.Stop()
go func() { go func() {
select { select {
case <-cmd.Context().Done(): case <-ctx.Done():
return return
case <-timer.C: case <-timer.C:
} }
@ -63,17 +64,17 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
} }
// This saves the cursor position, then defers clearing from the cursor // This saves the cursor position, then defers clearing from the cursor
// position to the end of the screen. // position to the end of the screen.
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message)) _, _ = fmt.Fprintf(writer, "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message))
defer fmt.Fprintf(cmd.OutOrStdout(), "\033[u\033[J") defer fmt.Fprintf(writer, "\033[u\033[J")
}() }()
for { for {
select { select {
case <-cmd.Context().Done(): case <-ctx.Done():
return cmd.Context().Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
} }
resourceMutex.Lock() resourceMutex.Lock()
resource, err = opts.Fetch(cmd.Context()) resource, err = opts.Fetch(ctx)
if err != nil { if err != nil {
return xerrors.Errorf("fetch: %w", err) return xerrors.Errorf("fetch: %w", err)
} }

View File

@ -20,7 +20,7 @@ func TestAgent(t *testing.T) {
ptty := ptytest.New(t) ptty := ptytest.New(t)
cmd := &cobra.Command{ cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
err := cliui.Agent(cmd, cliui.AgentOptions{ err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{
WorkspaceName: "example", WorkspaceName: "example",
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
resource := codersdk.WorkspaceResource{ resource := codersdk.WorkspaceResource{

38
cli/cliui/log.go Normal file
View File

@ -0,0 +1,38 @@
package cliui
import (
"fmt"
"io"
"strings"
"github.com/charmbracelet/lipgloss"
)
// cliMessage provides a human-readable message for CLI errors and messages.
type cliMessage struct {
Level string
Style lipgloss.Style
Header string
Lines []string
}
// String formats the CLI message for consumption by a human.
func (m cliMessage) String() string {
var str strings.Builder
_, _ = fmt.Fprintf(&str, "%s\r\n",
Styles.Bold.Render(m.Header))
for _, line := range m.Lines {
_, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line)
}
return str.String()
}
// Warn writes a log to the writer provided.
func Warn(wtr io.Writer, header string, lines ...string) {
_, _ = fmt.Fprint(wtr, cliMessage{
Level: "warning",
Style: Styles.Warn,
Header: header,
Lines: lines,
}.String())
}

View File

@ -2,6 +2,7 @@ package cliui
import ( import (
"bufio" "bufio"
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -62,7 +63,11 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
var rawMessage json.RawMessage var rawMessage json.RawMessage
err := json.NewDecoder(pipeReader).Decode(&rawMessage) err := json.NewDecoder(pipeReader).Decode(&rawMessage)
if err == nil { if err == nil {
line = string(rawMessage) var buf bytes.Buffer
err = json.Compact(&buf, rawMessage)
if err == nil {
line = buf.String()
}
} }
} }
} }

View File

@ -93,9 +93,7 @@ func TestPrompt(t *testing.T) {
ptty.WriteLine(`{ ptty.WriteLine(`{
"test": "wow" "test": "wow"
}`) }`)
require.Equal(t, `{ require.Equal(t, `{"test":"wow"}`, <-doneChan)
"test": "wow"
}`, <-doneChan)
}) })
} }

View File

@ -3,27 +3,27 @@ package cliui
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/spf13/cobra"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
func WorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, build uuid.UUID, before time.Time) error { func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Client, build uuid.UUID, before time.Time) error {
return ProvisionerJob(cmd, ProvisionerJobOptions{ return ProvisionerJob(ctx, writer, ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build) build, err := client.WorkspaceBuild(ctx, build)
return build.Job, err return build.Job, err
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
return client.WorkspaceBuildLogsAfter(cmd.Context(), build, before) return client.WorkspaceBuildLogsAfter(ctx, build, before)
}, },
}) })
} }
@ -39,25 +39,25 @@ type ProvisionerJobOptions struct {
} }
// ProvisionerJob renders a provisioner job with interactive cancellation. // ProvisionerJob renders a provisioner job with interactive cancellation.
func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error { func ProvisionerJob(ctx context.Context, writer io.Writer, opts ProvisionerJobOptions) error {
if opts.FetchInterval == 0 { if opts.FetchInterval == 0 {
opts.FetchInterval = time.Second opts.FetchInterval = time.Second
} }
ctx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
var ( var (
currentStage = "Queued" currentStage = "Queued"
currentStageStartedAt = time.Now().UTC() currentStageStartedAt = time.Now().UTC()
didLogBetweenStage = false didLogBetweenStage = false
ctx, cancelFunc = context.WithCancel(cmd.Context())
errChan = make(chan error, 1) errChan = make(chan error, 1)
job codersdk.ProvisionerJob job codersdk.ProvisionerJob
jobMutex sync.Mutex jobMutex sync.Mutex
) )
defer cancelFunc()
printStage := func() { printStage := func() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), Styles.Prompt.Render("⧗")+"%s\n", Styles.Field.Render(currentStage)) _, _ = fmt.Fprintf(writer, Styles.Prompt.Render("⧗")+"%s\n", Styles.Field.Render(currentStage))
} }
updateStage := func(stage string, startedAt time.Time) { updateStage := func(stage string, startedAt time.Time) {
@ -70,7 +70,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
if job.CompletedAt != nil && job.Status != codersdk.ProvisionerJobSucceeded { if job.CompletedAt != nil && job.Status != codersdk.ProvisionerJobSucceeded {
mark = Styles.Crossmark mark = Styles.Crossmark
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), prefix+mark.String()+Styles.Placeholder.Render(" %s [%dms]")+"\n", currentStage, startedAt.Sub(currentStageStartedAt).Milliseconds()) _, _ = fmt.Fprintf(writer, prefix+mark.String()+Styles.Placeholder.Render(" %s [%dms]")+"\n", currentStage, startedAt.Sub(currentStageStartedAt).Milliseconds())
} }
if stage == "" { if stage == "" {
return return
@ -116,7 +116,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
return return
} }
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\033[2K\r\n"+Styles.FocusedPrompt.String()+Styles.Bold.Render("Gracefully canceling...")+"\n\n") _, _ = fmt.Fprintf(writer, "\033[2K\r\n"+Styles.FocusedPrompt.String()+Styles.Bold.Render("Gracefully canceling...")+"\n\n")
err := opts.Cancel() err := opts.Cancel()
if err != nil { if err != nil {
errChan <- xerrors.Errorf("cancel: %w", err) errChan <- xerrors.Errorf("cancel: %w", err)
@ -183,7 +183,7 @@ func ProvisionerJob(cmd *cobra.Command, opts ProvisionerJobOptions) error {
jobMutex.Unlock() jobMutex.Unlock()
continue continue
} }
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s %s\n", Styles.Placeholder.Render(" "), output) _, _ = fmt.Fprintf(writer, "%s %s\n", Styles.Placeholder.Render(" "), output)
didLogBetweenStage = true didLogBetweenStage = true
jobMutex.Unlock() jobMutex.Unlock()
} }

View File

@ -126,7 +126,7 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
logs := make(chan codersdk.ProvisionerJobLog, 1) logs := make(chan codersdk.ProvisionerJobLog, 1)
cmd := &cobra.Command{ cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
return cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ return cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
FetchInterval: time.Millisecond, FetchInterval: time.Millisecond,
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
jobLock.Lock() jobLock.Lock()

View File

@ -1,26 +1,162 @@
package cli package cli
import ( import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/cli/safeexec"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
) )
// const sshStartToken = "# ------------START-CODER-----------" const sshStartToken = "# ------------START-CODER-----------"
// const sshStartMessage = `# This was generated by "coder config-ssh". const sshStartMessage = `# This was generated by "coder config-ssh".
// # #
// # To remove this blob, run: # To remove this blob, run:
// # #
// # coder config-ssh --remove # coder config-ssh --remove
// # #
// # You should not hand-edit this section, unless you are deleting it.` # You should not hand-edit this section, unless you are deleting it.`
// const sshEndToken = "# ------------END-CODER------------" const sshEndToken = "# ------------END-CODER------------"
func configSSH() *cobra.Command { func configSSH() *cobra.Command {
var (
sshConfigFile string
)
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "config-ssh", Use: "config-ssh",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
client, err := createClient(cmd)
if err != nil {
return err
}
if strings.HasPrefix(sshConfigFile, "~/") {
dirname, _ := os.UserHomeDir()
sshConfigFile = filepath.Join(dirname, sshConfigFile[2:])
}
// Doesn't matter if this fails, because we write the file anyways.
sshConfigContentRaw, _ := os.ReadFile(sshConfigFile)
sshConfigContent := string(sshConfigContentRaw)
startIndex := strings.Index(sshConfigContent, sshStartToken)
endIndex := strings.Index(sshConfigContent, sshEndToken)
if startIndex != -1 && endIndex != -1 {
sshConfigContent = sshConfigContent[:startIndex-1] + sshConfigContent[endIndex+len(sshEndToken):]
}
workspaces, err := client.WorkspacesByUser(cmd.Context(), "")
if err != nil {
return err
}
if len(workspaces) == 0 {
return xerrors.New("You don't have any workspaces!")
}
binPath, err := currentBinPath(cmd)
if err != nil {
return err
}
sshConfigContent += "\n" + sshStartToken + "\n" + sshStartMessage + "\n\n"
sshConfigContentMutex := sync.Mutex{}
var errGroup errgroup.Group
for _, workspace := range workspaces {
workspace := workspace
errGroup.Go(func() error {
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
if err != nil {
return err
}
resourcesWithAgents := make([]codersdk.WorkspaceResource, 0)
for _, resource := range resources {
if resource.Agent == nil {
continue
}
resourcesWithAgents = append(resourcesWithAgents, resource)
}
sshConfigContentMutex.Lock()
defer sshConfigContentMutex.Unlock()
if len(resourcesWithAgents) == 1 {
sshConfigContent += strings.Join([]string{
"Host coder." + workspace.Name,
"\tHostName coder." + workspace.Name,
fmt.Sprintf("\tProxyCommand %q ssh --stdio %s", binPath, workspace.Name),
"\tConnectTimeout=0",
"\tStrictHostKeyChecking=no",
}, "\n") + "\n"
}
return nil
})
}
err = errGroup.Wait()
if err != nil {
return err
}
sshConfigContent += "\n" + sshEndToken
err = os.MkdirAll(filepath.Dir(sshConfigFile), os.ModePerm)
if err != nil {
return err
}
err = os.WriteFile(sshConfigFile, []byte(sshConfigContent), os.ModePerm)
if err != nil {
return err
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "An auto-generated ssh config was written to %q\n", sshConfigFile)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "You should now be able to ssh into your workspace")
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "For example, try running\n\n\t$ ssh coder.%s\n\n", workspaces[0].Name)
return nil return nil
}, },
} }
cliflag.StringVarP(cmd.Flags(), &sshConfigFile, "ssh-config-file", "", "CODER_SSH_CONFIG_FILE", "~/.ssh/config", "Specifies the path to an SSH config.")
return cmd return cmd
} }
// currentBinPath returns the path to the coder binary suitable for use in ssh
// ProxyCommand.
func currentBinPath(cmd *cobra.Command) (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", xerrors.Errorf("get executable path: %w", err)
}
binName := filepath.Base(exePath)
// We use safeexec instead of os/exec because os/exec returns paths in
// the current working directory, which we will run into very often when
// looking for our own path.
pathPath, err := safeexec.LookPath(binName)
// On Windows, the coder-cli executable must be in $PATH for both Msys2/Git
// Bash and OpenSSH for Windows (used by Powershell and VS Code) to function
// correctly. Check if the current executable is in $PATH, and warn the user
// if it isn't.
if err != nil && runtime.GOOS == "windows" {
cliui.Warn(cmd.OutOrStdout(),
"The current executable is not in $PATH.",
"This may lead to problems connecting to your workspace via SSH.",
fmt.Sprintf("Please move %q to a location in your $PATH (such as System32) and run `%s config-ssh` again.", binName, binName),
)
// Return the exePath so SSH at least works outside of Msys2.
return exePath, nil
}
// Warn the user if the current executable is not the same as the one in
// $PATH.
if filepath.Clean(pathPath) != filepath.Clean(exePath) {
cliui.Warn(cmd.OutOrStdout(),
"The current executable path does not match the executable path found in $PATH.",
"This may cause issues connecting to your workspace via SSH.",
fmt.Sprintf("\tCurrent executable path: %q", exePath),
fmt.Sprintf("\tExecutable path in $PATH: %q", pathPath),
)
}
return binName, nil
}

41
cli/configssh_test.go Normal file
View File

@ -0,0 +1,41 @@
package cli_test
import (
"os"
"testing"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/require"
)
func TestConfigSSH(t *testing.T) {
t.Parallel()
t.Run("Create", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, client)
version := coderdtest.CreateProjectVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitProjectVersionJob(t, client, version.ID)
project := coderdtest.CreateProject(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, "", project.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
tempFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_ = tempFile.Close()
cmd, root := clitest.New(t, "config-ssh", "--ssh-config-file", tempFile.Name())
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
<-doneChan
})
}

View File

@ -125,7 +125,7 @@ func createValidProjectVersion(cmd *cobra.Command, client *codersdk.Client, orga
return nil, nil, err return nil, nil, err
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
version, err := client.ProjectVersion(cmd.Context(), version.ID) version, err := client.ProjectVersion(cmd.Context(), version.ID)
return version.Job, err return version.Job, err

View File

@ -2,18 +2,28 @@ package cli
import ( import (
"context" "context"
"io"
"net"
"os"
"time"
"github.com/mattn/go-isatty"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
"github.com/spf13/cobra" "github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/cliui"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"golang.org/x/crypto/ssh/terminal"
) )
func ssh() *cobra.Command { func ssh() *cobra.Command {
var (
stdio bool
)
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "ssh <workspace> [resource]", Use: "ssh <workspace> [resource]",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
@ -25,8 +35,11 @@ func ssh() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
if workspace.LatestBuild.Transition != database.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to ssh")
}
if workspace.LatestBuild.Job.CompletedAt == nil { if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd, client, workspace.LatestBuild.ID, workspace.CreatedAt) err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil { if err != nil {
return err return err
} }
@ -67,7 +80,9 @@ func ssh() *cobra.Command {
} }
return xerrors.Errorf("no sshable agent with address %q: %+v", resourceAddress, resourceKeys) return xerrors.Errorf("no sshable agent with address %q: %+v", resourceAddress, resourceKeys)
} }
err = cliui.Agent(cmd, cliui.AgentOptions{ // OpenSSH passes stderr directly to the calling TTY.
// This is required in "stdio" mode so a connecting indicator can be displayed.
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
return client.WorkspaceResource(ctx, resource.ID) return client.WorkspaceResource(ctx, resource.ID)
@ -84,6 +99,17 @@ func ssh() *cobra.Command {
return err return err
} }
defer conn.Close() defer conn.Close()
if stdio {
rawSSH, err := conn.SSH()
if err != nil {
return err
}
go func() {
_, _ = io.Copy(cmd.OutOrStdout(), rawSSH)
}()
_, _ = io.Copy(rawSSH, cmd.InOrStdin())
return nil
}
sshClient, err := conn.SSHClient() sshClient, err := conn.SSHClient()
if err != nil { if err != nil {
return err return err
@ -94,9 +120,17 @@ func ssh() *cobra.Command {
return err return err
} }
err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{ if isatty.IsTerminal(os.Stdout.Fd()) {
gossh.OCRNL: 1, state, err := terminal.MakeRaw(int(os.Stdin.Fd()))
}) if err != nil {
return err
}
defer func() {
_ = terminal.Restore(int(os.Stdin.Fd()), state)
}()
}
err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{})
if err != nil { if err != nil {
return err return err
} }
@ -115,6 +149,36 @@ func ssh() *cobra.Command {
return nil return nil
}, },
} }
cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.")
return cmd return cmd
} }
type stdioConn struct {
io.Reader
io.Writer
}
func (*stdioConn) Close() (err error) {
return nil
}
func (*stdioConn) LocalAddr() net.Addr {
return nil
}
func (*stdioConn) RemoteAddr() net.Addr {
return nil
}
func (*stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@ -1,10 +1,15 @@
package cli_test package cli_test
import ( import (
"io"
"net"
"runtime"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"cdr.dev/slog" "cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
@ -78,4 +83,113 @@ func TestSSH(t *testing.T) {
pty.WriteLine("exit") pty.WriteLine("exit")
<-doneChan <-doneChan
}) })
t.Run("Stdio", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateProjectVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agent: &proto.Agent{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
},
}},
},
},
}},
})
coderdtest.AwaitProjectVersionJob(t, client, version.ID)
project := coderdtest.CreateProject(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, "", project.ID)
go func() {
// Run this async so the SSH command has to wait for
// the build and agent to connect!
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
}()
clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe()
cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
cmd.SetIn(clientOutput)
cmd.SetOut(serverInput)
cmd.SetErr(io.Discard)
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
Reader: serverOutput,
Writer: clientInput,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
sshClient := ssh.NewClient(conn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
command := "sh -c exit"
if runtime.GOOS == "windows" {
command = "cmd.exe /c exit"
}
err = session.Run(command)
require.NoError(t, err)
err = sshClient.Close()
require.NoError(t, err)
_ = clientOutput.Close()
<-doneChan
})
}
type stdioConn struct {
io.Reader
io.Writer
}
func (*stdioConn) Close() (err error) {
return nil
}
func (*stdioConn) LocalAddr() net.Addr {
return nil
}
func (*stdioConn) RemoteAddr() net.Addr {
return nil
}
func (*stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (*stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
} }

View File

@ -266,7 +266,7 @@ func start() *cobra.Command {
return xerrors.Errorf("delete workspace: %w", err) return xerrors.Errorf("delete workspace: %w", err)
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID) build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err return build.Job, err
@ -313,7 +313,7 @@ func start() *cobra.Command {
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder") cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard") cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
// systemd uses the CACHE_DIRECTORY environment variable! // systemd uses the CACHE_DIRECTORY environment variable!
cliflag.StringVarP(root.Flags(), &cacheDir, "cache-dir", "", "CACHE_DIRECTORY", filepath.Join(os.TempDir(), ".coder-cache"), "Specifies a directory to cache binaries for provision operations.") cliflag.StringVarP(root.Flags(), &cacheDir, "cache-dir", "", "CACHE_DIRECTORY", filepath.Join(os.TempDir(), "coder-cache"), "Specifies a directory to cache binaries for provision operations.")
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering") cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.") cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
@ -369,6 +369,11 @@ func createFirstUser(cmd *cobra.Command, client *codersdk.Client, cfg config.Roo
} }
func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger slog.Logger, cacheDir string) (*provisionerd.Server, error) { func newProvisionerDaemon(ctx context.Context, client *codersdk.Client, logger slog.Logger, cacheDir string) (*provisionerd.Server, error) {
err := os.MkdirAll(cacheDir, 0700)
if err != nil {
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
}
terraformClient, terraformServer := provisionersdk.TransportPipe() terraformClient, terraformServer := provisionersdk.TransportPipe()
go func() { go func() {
err := terraform.Serve(ctx, &terraform.ServeOptions{ err := terraform.Serve(ctx, &terraform.ServeOptions{

View File

@ -145,7 +145,7 @@ func workspaceCreate() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), workspace.LatestBuild.ID) build, err := client.WorkspaceBuild(cmd.Context(), workspace.LatestBuild.ID)
return build.Job, err return build.Job, err

View File

@ -32,7 +32,7 @@ func workspaceDelete() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID) build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err return build.Job, err

View File

@ -31,7 +31,7 @@ func workspaceStart() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID) build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err return build.Job, err

View File

@ -31,7 +31,7 @@ func workspaceStop() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
err = cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
build, err := client.WorkspaceBuild(cmd.Context(), build.ID) build, err := client.WorkspaceBuild(cmd.Context(), build.ID)
return build.Job, err return build.Job, err

View File

@ -96,7 +96,7 @@ func main() {
job.Status = codersdk.ProvisionerJobSucceeded job.Status = codersdk.ProvisionerJobSucceeded
}() }()
err := cliui.ProvisionerJob(cmd, cliui.ProvisionerJobOptions{ err := cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
return job, nil return job, nil
}, },
@ -172,7 +172,7 @@ func main() {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
resource.Agent.Status = codersdk.WorkspaceAgentConnected resource.Agent.Status = codersdk.WorkspaceAgentConnected
}() }()
err := cliui.Agent(cmd, cliui.AgentOptions{ err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{
WorkspaceName: "dev", WorkspaceName: "dev",
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) { Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
return resource, nil return resource, nil

View File

@ -137,8 +137,6 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return return
} }
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
defer func() { defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "") _ = conn.Close(websocket.StatusNormalClosure, "")
}() }()
@ -183,6 +181,23 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
} }
return nil return nil
} }
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID.String() != latestBuild.ID.String() {
return xerrors.New("build is outdated")
}
return nil
}
defer func() { defer func() {
disconnectedAt = sql.NullTime{ disconnectedAt = sql.NullTime{
@ -197,6 +212,13 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) _ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return return
} }
err = ensureLatestBuild()
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop() defer ticker.Stop()
@ -214,6 +236,12 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) _ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return return
} }
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
} }
} }
} }

View File

@ -22,7 +22,7 @@ Coder requires a Google Cloud Service Account to provision workspaces.
- Service Account User - Service Account User
3. Click on the created key, and navigate to the "Keys" tab. 3. Click on the created key, and navigate to the "Keys" tab.
4. Click "Add key", then "Create new key". 4. Click "Add key", then "Create new key".
5. Generate a JSON private key, and paste the contents in \'\' quotes below. 5. Generate a JSON private key, and paste the contents below.
EOF EOF
sensitive = true sensitive = true
} }

View File

@ -22,7 +22,7 @@ Coder requires a Google Cloud Service Account to provision workspaces.
- Service Account User - Service Account User
3. Click on the created key, and navigate to the "Keys" tab. 3. Click on the created key, and navigate to the "Keys" tab.
4. Click "Add key", then "Create new key". 4. Click "Add key", then "Create new key".
5. Generate a JSON private key, and paste the contents in \'\' quotes below. 5. Generate a JSON private key, and paste the contents below.
EOF EOF
sensitive = true sensitive = true
} }

5
go.mod
View File

@ -14,6 +14,9 @@ replace github.com/hashicorp/terraform-config-inspect => github.com/kylecarbs/te
// Required until https://github.com/chzyer/readline/pull/198 is merged. // Required until https://github.com/chzyer/readline/pull/198 is merged.
replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8 replace github.com/chzyer/readline => github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8
// Required until https://github.com/briandowns/spinner/pull/136 is merged.
replace github.com/briandowns/spinner => github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e
// opencensus-go leaks a goroutine by default. // opencensus-go leaks a goroutine by default.
replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b replace go.opencensus.io => github.com/kylecarbs/opencensus-go v0.23.1-0.20220307014935-4d0325a68f8b
@ -90,6 +93,8 @@ require (
storj.io/drpc v0.0.30 storj.io/drpc v0.0.30
) )
require github.com/cli/safeexec v1.0.0
require ( require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/BurntSushi/toml v1.0.0 // indirect github.com/BurntSushi/toml v1.0.0 // indirect

6
go.sum
View File

@ -254,8 +254,6 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
github.com/briandowns/spinner v1.18.1 h1:yhQmQtM1zsqFsouh09Bk/jCjd50pC3EOGsh28gLVvwY=
github.com/briandowns/spinner v1.18.1/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU=
github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk=
github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
@ -304,6 +302,8 @@ github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6D
github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I=
github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E= github.com/clbanning/mxj/v2 v2.5.5 h1:oT81vUeEiQQ/DcHbzSytRngP6Ky9O+L+0Bw0zSJag9E=
github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/clbanning/mxj/v2 v2.5.5/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s=
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93 h1:QrGfkZDnMxcWHaYDdB7CmqS9i26OAnUj/xcus/abYkY= github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93 h1:QrGfkZDnMxcWHaYDdB7CmqS9i26OAnUj/xcus/abYkY=
github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93/go.mod h1:QiTe66jFdP7cUKMCCf/WrvDyYdtdmdZfVcdoLbzaKVY= github.com/cloudflare/brotli-go v0.0.0-20191101163834-d34379f7ff93/go.mod h1:QiTe66jFdP7cUKMCCf/WrvDyYdtdmdZfVcdoLbzaKVY=
@ -1126,6 +1126,8 @@ github.com/kylecarbs/promptui v0.8.1-0.20201231190244-d8f2159af2b2 h1:MUREBTh4ky
github.com/kylecarbs/promptui v0.8.1-0.20201231190244-d8f2159af2b2/go.mod h1:n4zTdgP0vr0S3w7/O/g98U+e0gwLScEXGwov2nIKuGQ= github.com/kylecarbs/promptui v0.8.1-0.20201231190244-d8f2159af2b2/go.mod h1:n4zTdgP0vr0S3w7/O/g98U+e0gwLScEXGwov2nIKuGQ=
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8 h1:Y7O3Z3YeNRtw14QrtHpevU4dSjCkov0J40MtQ7Nc0n8= github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8 h1:Y7O3Z3YeNRtw14QrtHpevU4dSjCkov0J40MtQ7Nc0n8=
github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY= github.com/kylecarbs/readline v0.0.0-20220211054233-0d62993714c8/go.mod h1:n/KX1BZoN1m9EwoXkn/xAV4fd3k8c++gGBsgLONaPOY=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e h1:OP0ZMFeZkUnOzTFRfpuK3m7Kp4fNvC6qN+exwj7aI4M=
github.com/kylecarbs/spinner v1.18.2-0.20220329160715-20702b5af89e/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU=
github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 h1:tvG/qs5c4worwGyGnbbb4i/dYYLjpFwDMqcIT3awAf8= github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 h1:tvG/qs5c4worwGyGnbbb4i/dYYLjpFwDMqcIT3awAf8=
github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88/go.mod h1:Z0Nnk4+3Cy89smEbrq+sl1bxc9198gIP4I7wcQF6Kqs= github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88/go.mod h1:Z0Nnk4+3Cy89smEbrq+sl1bxc9198gIP4I7wcQF6Kqs=
github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180 h1:yafC0pmxjs18fnO5RdKFLSItJIjYwGfSHTfcUvlZb3E= github.com/kylecarbs/terraform-exec v0.15.1-0.20220202050609-a1ce7181b180 h1:yafC0pmxjs18fnO5RdKFLSItJIjYwGfSHTfcUvlZb3E=

View File

@ -9,6 +9,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"github.com/awalterschulze/gographviz" "github.com/awalterschulze/gographviz"
@ -87,7 +88,9 @@ func (t *terraform) Provision(stream proto.DRPCProvisioner_ProvisionStream) erro
}) })
} }
}() }()
if t.cachePath != "" { // Windows doesn't work with a plugin cache directory.
// The cause is unknown, but it should work.
if t.cachePath != "" && runtime.GOOS != "windows" {
err = terraform.SetEnv(map[string]string{ err = terraform.SetEnv(map[string]string{
"TF_PLUGIN_CACHE_DIR": t.cachePath, "TF_PLUGIN_CACHE_DIR": t.cachePath,
}) })

View File

@ -2,13 +2,14 @@ package terraform
import ( import (
"context" "context"
"os/exec" "path/filepath"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/cli/safeexec"
"github.com/coder/coder/provisionersdk" "github.com/coder/coder/provisionersdk"
"github.com/hashicorp/hc-install/product" "github.com/hashicorp/hc-install/product"
@ -41,7 +42,7 @@ type ServeOptions struct {
// Serve starts a dRPC server on the provided transport speaking Terraform provisioner. // Serve starts a dRPC server on the provided transport speaking Terraform provisioner.
func Serve(ctx context.Context, options *ServeOptions) error { func Serve(ctx context.Context, options *ServeOptions) error {
if options.BinaryPath == "" { if options.BinaryPath == "" {
binaryPath, err := exec.LookPath("terraform") binaryPath, err := safeexec.LookPath("terraform")
if err != nil { if err != nil {
installer := &releases.ExactVersion{ installer := &releases.ExactVersion{
InstallDir: options.CachePath, InstallDir: options.CachePath,
@ -55,7 +56,16 @@ func Serve(ctx context.Context, options *ServeOptions) error {
} }
options.BinaryPath = execPath options.BinaryPath = execPath
} else { } else {
options.BinaryPath = binaryPath // If the "coder" binary is in the same directory as
// the "terraform" binary, "terraform" is returned.
//
// We must resolve the absolute path for other processes
// to execute this properly!
absoluteBinary, err := filepath.Abs(binaryPath)
if err != nil {
return xerrors.Errorf("absolute: %w", err)
}
options.BinaryPath = absoluteBinary
} }
} }
return provisionersdk.Serve(ctx, &terraform{ return provisionersdk.Serve(ctx, &terraform{