fix: add agent exec abstraction (#15717)

This commit is contained in:
Jon Ayers
2024-12-04 23:30:25 +02:00
committed by GitHub
parent 6c9ccca687
commit ce573b9faa
16 changed files with 210 additions and 192 deletions

View File

@ -33,6 +33,7 @@ import (
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts" "github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/agent/proto"
@ -80,6 +81,7 @@ type Options struct {
ReportMetadataInterval time.Duration ReportMetadataInterval time.Duration
ServiceBannerRefreshInterval time.Duration ServiceBannerRefreshInterval time.Duration
BlockFileTransfer bool BlockFileTransfer bool
Execer agentexec.Execer
} }
type Client interface { type Client interface {
@ -139,6 +141,10 @@ func New(options Options) Agent {
prometheusRegistry = prometheus.NewRegistry() prometheusRegistry = prometheus.NewRegistry()
} }
if options.Execer == nil {
options.Execer = agentexec.DefaultExecer
}
hardCtx, hardCancel := context.WithCancel(context.Background()) hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx) gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{ a := &agent{
@ -171,6 +177,7 @@ func New(options Options) Agent {
prometheusRegistry: prometheusRegistry, prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry), metrics: newAgentMetrics(prometheusRegistry),
execer: options.Execer,
} }
// Initially, we have a closed channel, reflecting the fact that we are not initially connected. // Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one // Each time we connect we replace the channel (while holding the closeMutex) with a new one
@ -239,6 +246,7 @@ type agent struct {
// metrics are prometheus registered metrics that will be collected and // metrics are prometheus registered metrics that will be collected and
// labeled in Coder with the agent + workspace. // labeled in Coder with the agent + workspace.
metrics *agentMetrics metrics *agentMetrics
execer agentexec.Execer
} }
func (a *agent) TailnetConn() *tailnet.Conn { func (a *agent) TailnetConn() *tailnet.Conn {
@ -247,7 +255,7 @@ func (a *agent) TailnetConn() *tailnet.Conn {
func (a *agent) init() { func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown. // pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{ sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout, MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile }, MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() }, AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },

View File

@ -17,9 +17,6 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
) )
// unset is set to an invalid value for nice and oom scores.
const unset = -2000
// CLI runs the agent-exec command. It should only be called by the cli package. // CLI runs the agent-exec command. It should only be called by the cli package.
func CLI() error { func CLI() error {
// We lock the OS thread here to avoid a race condition where the nice priority // We lock the OS thread here to avoid a race condition where the nice priority

View File

@ -20,60 +20,101 @@ const (
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT" EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
EnvProcOOMScore = "CODER_PROC_OOM_SCORE" EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
EnvProcNiceScore = "CODER_PROC_NICE_SCORE" EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
// unset is set to an invalid value for nice and oom scores.
unset = -2000
) )
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing var DefaultExecer Execer = execer{}
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
// is returned. All instances of exec.Cmd should flow through this function to ensure // Execer defines an abstraction for creating exec.Cmd variants. It's unfortunately
// proper resource constraints are applied to the child process. // necessary because we need to be able to wrap child processes with "coder agent-exec"
func CommandContext(ctx context.Context, cmd string, args ...string) (*exec.Cmd, error) { // for templates that expect the agent to manage process priority.
cmd, args, err := agentExecCmd(cmd, args...) type Execer interface {
if err != nil { // CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
return nil, xerrors.Errorf("agent exec cmd: %w", err) // the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
} // is returned. All instances of exec.Cmd should flow through this function to ensure
return exec.CommandContext(ctx, cmd, args...), nil // proper resource constraints are applied to the child process.
CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
// is returned. All instances of pty.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd
} }
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing func NewExecer() (Execer, error) {
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
// is returned. All instances of pty.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
func PTYCommandContext(ctx context.Context, cmd string, args ...string) (*pty.Cmd, error) {
cmd, args, err := agentExecCmd(cmd, args...)
if err != nil {
return nil, xerrors.Errorf("agent exec cmd: %w", err)
}
return pty.CommandContext(ctx, cmd, args...), nil
}
func agentExecCmd(cmd string, args ...string) (string, []string, error) {
_, enabled := os.LookupEnv(EnvProcPrioMgmt) _, enabled := os.LookupEnv(EnvProcPrioMgmt)
if runtime.GOOS != "linux" || !enabled { if runtime.GOOS != "linux" || !enabled {
return cmd, args, nil return DefaultExecer, nil
} }
executable, err := os.Executable() executable, err := os.Executable()
if err != nil { if err != nil {
return "", nil, xerrors.Errorf("get executable: %w", err) return nil, xerrors.Errorf("get executable: %w", err)
} }
bin, err := filepath.EvalSymlinks(executable) bin, err := filepath.EvalSymlinks(executable)
if err != nil { if err != nil {
return "", nil, xerrors.Errorf("eval symlinks: %w", err) return nil, xerrors.Errorf("eval symlinks: %w", err)
} }
oomScore, ok := envValInt(EnvProcOOMScore)
if !ok {
oomScore = unset
}
niceScore, ok := envValInt(EnvProcNiceScore)
if !ok {
niceScore = unset
}
return priorityExecer{
binPath: bin,
oomScore: oomScore,
niceScore: niceScore,
}, nil
}
type execer struct{}
func (execer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, cmd, args...)
}
func (execer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
return pty.CommandContext(ctx, cmd, args...)
}
type priorityExecer struct {
binPath string
oomScore int
niceScore int
}
func (e priorityExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
cmd, args = e.agentExecCmd(cmd, args...)
return exec.CommandContext(ctx, cmd, args...)
}
func (e priorityExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
cmd, args = e.agentExecCmd(cmd, args...)
return pty.CommandContext(ctx, cmd, args...)
}
func (e priorityExecer) agentExecCmd(cmd string, args ...string) (string, []string) {
execArgs := []string{"agent-exec"} execArgs := []string{"agent-exec"}
if score, ok := envValInt(EnvProcOOMScore); ok { if e.oomScore != unset {
execArgs = append(execArgs, oomScoreArg(score)) execArgs = append(execArgs, oomScoreArg(e.oomScore))
} }
if score, ok := envValInt(EnvProcNiceScore); ok { if e.niceScore != unset {
execArgs = append(execArgs, niceScoreArg(score)) execArgs = append(execArgs, niceScoreArg(e.niceScore))
} }
execArgs = append(execArgs, "--", cmd) execArgs = append(execArgs, "--", cmd)
execArgs = append(execArgs, args...) execArgs = append(execArgs, args...)
return bin, execArgs, nil return e.binPath, execArgs
} }
// envValInt searches for a key in a list of environment variables and parses it to an int. // envValInt searches for a key in a list of environment variables and parses it to an int.

View File

@ -0,0 +1,84 @@
package agentexec
import (
"context"
"os/exec"
"testing"
"github.com/stretchr/testify/require"
)
func TestExecer(t *testing.T) {
t.Parallel()
t.Run("Default", func(t *testing.T) {
t.Parallel()
cmd := DefaultExecer.CommandContext(context.Background(), "sh", "-c", "sleep")
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Priority", func(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: unset,
niceScore: unset,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Nice", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: unset,
niceScore: 10,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("OOM", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: 123,
niceScore: unset,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Both", func(t *testing.T) {
t.Parallel()
e := priorityExecer{
binPath: "/foo/bar/baz",
oomScore: 432,
niceScore: 14,
}
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
require.Equal(t, e.binPath, cmd.Path)
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
})
})
}

View File

@ -1,119 +0,0 @@
package agentexec_test
import (
"context"
"os"
"os/exec"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentexec"
)
//nolint:paralleltest // we need to test environment variables
func TestExec(t *testing.T) {
//nolint:paralleltest // we need to test environment variables
t.Run("NonLinux", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "true")
if runtime.GOOS == "linux" {
t.Skip("skipping on linux")
}
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
//nolint:paralleltest // we need to test environment variables
t.Run("Linux", func(t *testing.T) {
//nolint:paralleltest // we need to test environment variables
t.Run("Disabled", func(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
path, err := exec.LookPath("sh")
require.NoError(t, err)
require.Equal(t, path, cmd.Path)
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
})
//nolint:paralleltest // we need to test environment variables
t.Run("Enabled", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Nice", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcNiceScore, "10")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("OOM", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcOOMScore, "123")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
})
t.Run("Both", func(t *testing.T) {
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
t.Setenv(agentexec.EnvProcOOMScore, "432")
t.Setenv(agentexec.EnvProcNiceScore, "14")
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux")
}
executable, err := os.Executable()
require.NoError(t, err)
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
require.NoError(t, err)
require.Equal(t, executable, cmd.Path)
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
})
})
}

View File

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentscripts" "github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/agent/agenttest"
@ -160,7 +161,7 @@ func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscript
} }
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil) s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
_ = s.Close() _ = s.Close()

View File

@ -98,6 +98,7 @@ type Server struct {
// a lock on mu but protected by closing. // a lock on mu but protected by closing.
wg sync.WaitGroup wg sync.WaitGroup
Execer agentexec.Execer
logger slog.Logger logger slog.Logger
srv *ssh.Server srv *ssh.Server
@ -110,7 +111,7 @@ type Server struct {
metrics *sshServerMetrics metrics *sshServerMetrics
} }
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) { func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
// Clients' should ignore the host key when connecting. // Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH, // The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security. // so SSH authentication doesn't improve security.
@ -153,6 +154,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
metrics := newSSHServerMetrics(prometheusRegistry) metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{ s := &Server{
Execer: execer,
listeners: make(map[net.Listener]struct{}), listeners: make(map[net.Listener]struct{}),
fs: fs, fs: fs,
conns: make(map[net.Conn]struct{}), conns: make(map[net.Conn]struct{}),
@ -726,10 +728,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
} }
} }
cmd, err := agentexec.PTYCommandContext(ctx, name, args...) cmd := s.Execer.PTYCommandContext(ctx, name, args...)
if err != nil {
return nil, xerrors.Errorf("pty command context: %w", err)
}
cmd.Dir = s.config.WorkingDirectory() cmd.Dir = s.config.WorkingDirectory()
// If the metadata directory doesn't exist, we run the command // If the metadata directory doesn't exist, we run the command

View File

@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -35,7 +36,7 @@ func Test_sessionStart_orphan(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel() defer cancel()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()

View File

@ -22,6 +22,7 @@ import (
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/pty/ptytest" "github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
@ -36,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -77,7 +78,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
_ = s.Close() _ = s.Close()
@ -108,7 +109,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -159,7 +160,7 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@ -224,7 +225,7 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := testutil.Logger(t) logger := testutil.Logger(t)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()

View File

@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -34,7 +35,7 @@ func TestServer_X11(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := testutil.Logger(t) logger := testutil.Logger(t)
fs := afero.NewOsFs() fs := afero.NewOsFs()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{}) s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()

View File

@ -40,7 +40,7 @@ type bufferedReconnectingPTY struct {
// newBuffered starts the buffered pty. If the context ends the process will be // newBuffered starts the buffered pty. If the context ends the process will be
// killed. // killed.
func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY { func newBuffered(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *bufferedReconnectingPTY {
rpty := &bufferedReconnectingPTY{ rpty := &bufferedReconnectingPTY{
activeConns: map[string]net.Conn{}, activeConns: map[string]net.Conn{},
command: cmd, command: cmd,
@ -59,11 +59,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the // Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
// first argument so remove it. // first argument so remove it.
cmdWithEnv, err := agentexec.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...) cmdWithEnv := execer.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
if err != nil {
rpty.state.setState(StateDone, xerrors.Errorf("pty command context: %w", err))
return rpty
}
cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color") cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmdWithEnv.Dir = rpty.command.Dir cmdWithEnv.Dir = rpty.command.Dir
ptty, process, err := pty.Start(cmdWithEnv) ptty, process, err := pty.Start(cmdWithEnv)

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors" "golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/pty" "github.com/coder/coder/v2/pty"
) )
@ -55,7 +56,7 @@ type ReconnectingPTY interface {
// close itself (and all connections to it) if nothing is attached for the // close itself (and all connections to it) if nothing is attached for the
// duration of the timeout, if the context ends, or the process exits (buffered // duration of the timeout, if the context ends, or the process exits (buffered
// backend only). // backend only).
func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY { func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) ReconnectingPTY {
if options.Timeout == 0 { if options.Timeout == 0 {
options.Timeout = 5 * time.Minute options.Timeout = 5 * time.Minute
} }
@ -75,9 +76,9 @@ func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger
switch backendType { switch backendType {
case "screen": case "screen":
return newScreen(ctx, cmd, options, logger) return newScreen(ctx, logger, execer, cmd, options)
default: default:
return newBuffered(ctx, cmd, options, logger) return newBuffered(ctx, logger, execer, cmd, options)
} }
} }

View File

@ -25,6 +25,7 @@ import (
// screenReconnectingPTY provides a reconnectable PTY via `screen`. // screenReconnectingPTY provides a reconnectable PTY via `screen`.
type screenReconnectingPTY struct { type screenReconnectingPTY struct {
execer agentexec.Execer
command *pty.Cmd command *pty.Cmd
// id holds the id of the session for both creating and attaching. This will // id holds the id of the session for both creating and attaching. This will
@ -59,8 +60,9 @@ type screenReconnectingPTY struct {
// spawns the daemon with a hardcoded 24x80 size it is not a very good user // spawns the daemon with a hardcoded 24x80 size it is not a very good user
// experience. Instead we will let the attach command spawn the daemon on its // experience. Instead we will let the attach command spawn the daemon on its
// own which causes it to spawn with the specified size. // own which causes it to spawn with the specified size.
func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY { func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
rpty := &screenReconnectingPTY{ rpty := &screenReconnectingPTY{
execer: execer,
command: cmd, command: cmd,
metrics: options.Metrics, metrics: options.Metrics,
state: newState(), state: newState(),
@ -210,7 +212,7 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id)) logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id))
// Wrap the command with screen and tie it to the connection's context. // Wrap the command with screen and tie it to the connection's context.
cmd, err := agentexec.PTYCommandContext(ctx, "screen", append([]string{ cmd := rpty.execer.PTYCommandContext(ctx, "screen", append([]string{
// -S is for setting the session's name. // -S is for setting the session's name.
"-S", rpty.id, "-S", rpty.id,
// -U tells screen to use UTF-8 encoding. // -U tells screen to use UTF-8 encoding.
@ -223,9 +225,6 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
rpty.command.Path, rpty.command.Path,
// pty.Cmd duplicates Path as the first argument so remove it. // pty.Cmd duplicates Path as the first argument so remove it.
}, rpty.command.Args[1:]...)...) }, rpty.command.Args[1:]...)...)
if err != nil {
return nil, nil, xerrors.Errorf("pty command context: %w", err)
}
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmd.Dir = rpty.command.Dir cmd.Dir = rpty.command.Dir
ptty, process, err := pty.Start(cmd, pty.WithPTYOption( ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
@ -333,7 +332,7 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
run := func() (bool, error) { run := func() (bool, error) {
var stdout bytes.Buffer var stdout bytes.Buffer
//nolint:gosec //nolint:gosec
cmd, err := agentexec.CommandContext(ctx, "screen", cmd := rpty.execer.CommandContext(ctx, "screen",
// -x targets an attached session. // -x targets an attached session.
"-x", rpty.id, "-x", rpty.id,
// -c is the flag for the config file. // -c is the flag for the config file.
@ -341,13 +340,10 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
// -X runs a command in the matching session. // -X runs a command in the matching session.
"-X", command, "-X", command,
) )
if err != nil {
return false, xerrors.Errorf("command context: %w", err)
}
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color") cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
cmd.Dir = rpty.command.Dir cmd.Dir = rpty.command.Dir
cmd.Stdout = &stdout cmd.Stdout = &stdout
err = cmd.Run() err := cmd.Run()
if err == nil { if err == nil {
return true, nil return true, nil
} }

View File

@ -165,10 +165,15 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
return xerrors.Errorf("create command: %w", err) return xerrors.Errorf("create command: %w", err)
} }
rpty = New(ctx, cmd, &Options{ rpty = New(ctx,
Timeout: s.timeout, logger.With(slog.F("message_id", msg.ID)),
Metrics: s.errorsTotal, s.commandCreator.Execer,
}, logger.With(slog.F("message_id", msg.ID))) cmd,
&Options{
Timeout: s.timeout,
Metrics: s.errorsTotal,
},
)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {

View File

@ -309,6 +309,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
) )
} }
execer, err := agentexec.NewExecer()
if err != nil {
return xerrors.Errorf("create agent execer: %w", err)
}
agnt := agent.New(agent.Options{ agnt := agent.New(agent.Options{
Client: client, Client: client,
Logger: logger, Logger: logger,
@ -333,6 +338,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
PrometheusRegistry: prometheusRegistry, PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer, BlockFileTransfer: blockFileTransfer,
Execer: execer,
}) })
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger) promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)

View File

@ -503,7 +503,7 @@ func noExecInAgent(m dsl.Matcher) {
!m.File().PkgPath.Matches("/agentexec") && !m.File().PkgPath.Matches("/agentexec") &&
!m.File().Name.Matches(`_test\.go$`), !m.File().Name.Matches(`_test\.go$`),
). ).
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using agentexec.CommandContext instead.") Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using an agentexec.Execer instead.")
} }
// noPTYInAgent ensures that packages under agent/ don't use pty.Command or // noPTYInAgent ensures that packages under agent/ don't use pty.Command or
@ -521,5 +521,5 @@ func noPTYInAgent(m dsl.Matcher) {
!m.File().PkgPath.Matches(`/agentexec`) && !m.File().PkgPath.Matches(`/agentexec`) &&
!m.File().Name.Matches(`_test\.go$`), !m.File().Name.Matches(`_test\.go$`),
). ).
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using agentexec.PTYCommandContext instead.") Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using an agentexec.Execer instead.")
} }