mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
fix: add agent exec abstraction (#15717)
This commit is contained in:
@ -33,6 +33,7 @@ import (
|
|||||||
"tailscale.com/util/clientmetric"
|
"tailscale.com/util/clientmetric"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/agent/agentscripts"
|
"github.com/coder/coder/v2/agent/agentscripts"
|
||||||
"github.com/coder/coder/v2/agent/agentssh"
|
"github.com/coder/coder/v2/agent/agentssh"
|
||||||
"github.com/coder/coder/v2/agent/proto"
|
"github.com/coder/coder/v2/agent/proto"
|
||||||
@ -80,6 +81,7 @@ type Options struct {
|
|||||||
ReportMetadataInterval time.Duration
|
ReportMetadataInterval time.Duration
|
||||||
ServiceBannerRefreshInterval time.Duration
|
ServiceBannerRefreshInterval time.Duration
|
||||||
BlockFileTransfer bool
|
BlockFileTransfer bool
|
||||||
|
Execer agentexec.Execer
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
@ -139,6 +141,10 @@ func New(options Options) Agent {
|
|||||||
prometheusRegistry = prometheus.NewRegistry()
|
prometheusRegistry = prometheus.NewRegistry()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if options.Execer == nil {
|
||||||
|
options.Execer = agentexec.DefaultExecer
|
||||||
|
}
|
||||||
|
|
||||||
hardCtx, hardCancel := context.WithCancel(context.Background())
|
hardCtx, hardCancel := context.WithCancel(context.Background())
|
||||||
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
|
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
|
||||||
a := &agent{
|
a := &agent{
|
||||||
@ -171,6 +177,7 @@ func New(options Options) Agent {
|
|||||||
|
|
||||||
prometheusRegistry: prometheusRegistry,
|
prometheusRegistry: prometheusRegistry,
|
||||||
metrics: newAgentMetrics(prometheusRegistry),
|
metrics: newAgentMetrics(prometheusRegistry),
|
||||||
|
execer: options.Execer,
|
||||||
}
|
}
|
||||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||||
@ -239,6 +246,7 @@ type agent struct {
|
|||||||
// metrics are prometheus registered metrics that will be collected and
|
// metrics are prometheus registered metrics that will be collected and
|
||||||
// labeled in Coder with the agent + workspace.
|
// labeled in Coder with the agent + workspace.
|
||||||
metrics *agentMetrics
|
metrics *agentMetrics
|
||||||
|
execer agentexec.Execer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||||
@ -247,7 +255,7 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
|||||||
|
|
||||||
func (a *agent) init() {
|
func (a *agent) init() {
|
||||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
|
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||||
MaxTimeout: a.sshMaxTimeout,
|
MaxTimeout: a.sshMaxTimeout,
|
||||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||||
|
@ -17,9 +17,6 @@ import (
|
|||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// unset is set to an invalid value for nice and oom scores.
|
|
||||||
const unset = -2000
|
|
||||||
|
|
||||||
// CLI runs the agent-exec command. It should only be called by the cli package.
|
// CLI runs the agent-exec command. It should only be called by the cli package.
|
||||||
func CLI() error {
|
func CLI() error {
|
||||||
// We lock the OS thread here to avoid a race condition where the nice priority
|
// We lock the OS thread here to avoid a race condition where the nice priority
|
||||||
|
@ -20,60 +20,101 @@ const (
|
|||||||
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
|
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
|
||||||
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
|
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
|
||||||
EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
|
EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
|
||||||
|
|
||||||
|
// unset is set to an invalid value for nice and oom scores.
|
||||||
|
unset = -2000
|
||||||
)
|
)
|
||||||
|
|
||||||
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
|
var DefaultExecer Execer = execer{}
|
||||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
|
|
||||||
// is returned. All instances of exec.Cmd should flow through this function to ensure
|
// Execer defines an abstraction for creating exec.Cmd variants. It's unfortunately
|
||||||
// proper resource constraints are applied to the child process.
|
// necessary because we need to be able to wrap child processes with "coder agent-exec"
|
||||||
func CommandContext(ctx context.Context, cmd string, args ...string) (*exec.Cmd, error) {
|
// for templates that expect the agent to manage process priority.
|
||||||
cmd, args, err := agentExecCmd(cmd, args...)
|
type Execer interface {
|
||||||
if err != nil {
|
// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||||
return nil, xerrors.Errorf("agent exec cmd: %w", err)
|
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
|
||||||
}
|
// is returned. All instances of exec.Cmd should flow through this function to ensure
|
||||||
return exec.CommandContext(ctx, cmd, args...), nil
|
// proper resource constraints are applied to the child process.
|
||||||
|
CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd
|
||||||
|
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
|
||||||
|
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
|
||||||
|
// is returned. All instances of pty.Cmd should flow through this function to ensure
|
||||||
|
// proper resource constraints are applied to the child process.
|
||||||
|
PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// PTYCommandContext returns an pty.Cmd that calls "coder agent-exec" prior to exec'ing
|
func NewExecer() (Execer, error) {
|
||||||
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal pty.Cmd
|
|
||||||
// is returned. All instances of pty.Cmd should flow through this function to ensure
|
|
||||||
// proper resource constraints are applied to the child process.
|
|
||||||
func PTYCommandContext(ctx context.Context, cmd string, args ...string) (*pty.Cmd, error) {
|
|
||||||
cmd, args, err := agentExecCmd(cmd, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("agent exec cmd: %w", err)
|
|
||||||
}
|
|
||||||
return pty.CommandContext(ctx, cmd, args...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func agentExecCmd(cmd string, args ...string) (string, []string, error) {
|
|
||||||
_, enabled := os.LookupEnv(EnvProcPrioMgmt)
|
_, enabled := os.LookupEnv(EnvProcPrioMgmt)
|
||||||
if runtime.GOOS != "linux" || !enabled {
|
if runtime.GOOS != "linux" || !enabled {
|
||||||
return cmd, args, nil
|
return DefaultExecer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
executable, err := os.Executable()
|
executable, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, xerrors.Errorf("get executable: %w", err)
|
return nil, xerrors.Errorf("get executable: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
bin, err := filepath.EvalSymlinks(executable)
|
bin, err := filepath.EvalSymlinks(executable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, xerrors.Errorf("eval symlinks: %w", err)
|
return nil, xerrors.Errorf("eval symlinks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oomScore, ok := envValInt(EnvProcOOMScore)
|
||||||
|
if !ok {
|
||||||
|
oomScore = unset
|
||||||
|
}
|
||||||
|
|
||||||
|
niceScore, ok := envValInt(EnvProcNiceScore)
|
||||||
|
if !ok {
|
||||||
|
niceScore = unset
|
||||||
|
}
|
||||||
|
|
||||||
|
return priorityExecer{
|
||||||
|
binPath: bin,
|
||||||
|
oomScore: oomScore,
|
||||||
|
niceScore: niceScore,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type execer struct{}
|
||||||
|
|
||||||
|
func (execer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||||
|
return exec.CommandContext(ctx, cmd, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (execer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||||
|
return pty.CommandContext(ctx, cmd, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type priorityExecer struct {
|
||||||
|
binPath string
|
||||||
|
oomScore int
|
||||||
|
niceScore int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e priorityExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||||
|
cmd, args = e.agentExecCmd(cmd, args...)
|
||||||
|
return exec.CommandContext(ctx, cmd, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e priorityExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||||
|
cmd, args = e.agentExecCmd(cmd, args...)
|
||||||
|
return pty.CommandContext(ctx, cmd, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e priorityExecer) agentExecCmd(cmd string, args ...string) (string, []string) {
|
||||||
execArgs := []string{"agent-exec"}
|
execArgs := []string{"agent-exec"}
|
||||||
if score, ok := envValInt(EnvProcOOMScore); ok {
|
if e.oomScore != unset {
|
||||||
execArgs = append(execArgs, oomScoreArg(score))
|
execArgs = append(execArgs, oomScoreArg(e.oomScore))
|
||||||
}
|
}
|
||||||
|
|
||||||
if score, ok := envValInt(EnvProcNiceScore); ok {
|
if e.niceScore != unset {
|
||||||
execArgs = append(execArgs, niceScoreArg(score))
|
execArgs = append(execArgs, niceScoreArg(e.niceScore))
|
||||||
}
|
}
|
||||||
execArgs = append(execArgs, "--", cmd)
|
execArgs = append(execArgs, "--", cmd)
|
||||||
execArgs = append(execArgs, args...)
|
execArgs = append(execArgs, args...)
|
||||||
|
|
||||||
return bin, execArgs, nil
|
return e.binPath, execArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
// envValInt searches for a key in a list of environment variables and parses it to an int.
|
// envValInt searches for a key in a list of environment variables and parses it to an int.
|
||||||
|
84
agent/agentexec/exec_internal_test.go
Normal file
84
agent/agentexec/exec_internal_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package agentexec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExecer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("Default", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cmd := DefaultExecer.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||||
|
|
||||||
|
path, err := exec.LookPath("sh")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, path, cmd.Path)
|
||||||
|
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Priority", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("OK", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
e := priorityExecer{
|
||||||
|
binPath: "/foo/bar/baz",
|
||||||
|
oomScore: unset,
|
||||||
|
niceScore: unset,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||||
|
require.Equal(t, e.binPath, cmd.Path)
|
||||||
|
require.Equal(t, []string{e.binPath, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Nice", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
e := priorityExecer{
|
||||||
|
binPath: "/foo/bar/baz",
|
||||||
|
oomScore: unset,
|
||||||
|
niceScore: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||||
|
require.Equal(t, e.binPath, cmd.Path)
|
||||||
|
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OOM", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
e := priorityExecer{
|
||||||
|
binPath: "/foo/bar/baz",
|
||||||
|
oomScore: 123,
|
||||||
|
niceScore: unset,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||||
|
require.Equal(t, e.binPath, cmd.Path)
|
||||||
|
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Both", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
e := priorityExecer{
|
||||||
|
binPath: "/foo/bar/baz",
|
||||||
|
oomScore: 432,
|
||||||
|
niceScore: 14,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := e.CommandContext(context.Background(), "sh", "-c", "sleep")
|
||||||
|
require.Equal(t, e.binPath, cmd.Path)
|
||||||
|
require.Equal(t, []string{e.binPath, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
@ -1,119 +0,0 @@
|
|||||||
package agentexec_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/coder/coder/v2/agent/agentexec"
|
|
||||||
)
|
|
||||||
|
|
||||||
//nolint:paralleltest // we need to test environment variables
|
|
||||||
func TestExec(t *testing.T) {
|
|
||||||
//nolint:paralleltest // we need to test environment variables
|
|
||||||
t.Run("NonLinux", func(t *testing.T) {
|
|
||||||
t.Setenv(agentexec.EnvProcPrioMgmt, "true")
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
t.Skip("skipping on linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
path, err := exec.LookPath("sh")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, path, cmd.Path)
|
|
||||||
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
|
|
||||||
//nolint:paralleltest // we need to test environment variables
|
|
||||||
t.Run("Linux", func(t *testing.T) {
|
|
||||||
//nolint:paralleltest // we need to test environment variables
|
|
||||||
t.Run("Disabled", func(t *testing.T) {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
t.Skip("skipping on non-linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
path, err := exec.LookPath("sh")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, path, cmd.Path)
|
|
||||||
require.Equal(t, []string{"sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
|
|
||||||
//nolint:paralleltest // we need to test environment variables
|
|
||||||
t.Run("Enabled", func(t *testing.T) {
|
|
||||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
|
||||||
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
t.Skip("skipping on non-linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
executable, err := os.Executable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, executable, cmd.Path)
|
|
||||||
require.Equal(t, []string{executable, "agent-exec", "--", "sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Nice", func(t *testing.T) {
|
|
||||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
|
||||||
t.Setenv(agentexec.EnvProcNiceScore, "10")
|
|
||||||
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
t.Skip("skipping on non-linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
executable, err := os.Executable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, executable, cmd.Path)
|
|
||||||
require.Equal(t, []string{executable, "agent-exec", "--coder-nice=10", "--", "sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("OOM", func(t *testing.T) {
|
|
||||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
|
||||||
t.Setenv(agentexec.EnvProcOOMScore, "123")
|
|
||||||
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
t.Skip("skipping on non-linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
executable, err := os.Executable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, executable, cmd.Path)
|
|
||||||
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=123", "--", "sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Both", func(t *testing.T) {
|
|
||||||
t.Setenv(agentexec.EnvProcPrioMgmt, "hello")
|
|
||||||
t.Setenv(agentexec.EnvProcOOMScore, "432")
|
|
||||||
t.Setenv(agentexec.EnvProcNiceScore, "14")
|
|
||||||
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
t.Skip("skipping on non-linux")
|
|
||||||
}
|
|
||||||
|
|
||||||
executable, err := os.Executable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
cmd, err := agentexec.CommandContext(context.Background(), "sh", "-c", "sleep")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, executable, cmd.Path)
|
|
||||||
require.Equal(t, []string{executable, "agent-exec", "--coder-oom=432", "--coder-nice=14", "--", "sh", "-c", "sleep"}, cmd.Args)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/agent/agentscripts"
|
"github.com/coder/coder/v2/agent/agentscripts"
|
||||||
"github.com/coder/coder/v2/agent/agentssh"
|
"github.com/coder/coder/v2/agent/agentssh"
|
||||||
"github.com/coder/coder/v2/agent/agenttest"
|
"github.com/coder/coder/v2/agent/agenttest"
|
||||||
@ -160,7 +161,7 @@ func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscript
|
|||||||
}
|
}
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
|
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_ = s.Close()
|
_ = s.Close()
|
||||||
|
@ -98,6 +98,7 @@ type Server struct {
|
|||||||
// a lock on mu but protected by closing.
|
// a lock on mu but protected by closing.
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
Execer agentexec.Execer
|
||||||
logger slog.Logger
|
logger slog.Logger
|
||||||
srv *ssh.Server
|
srv *ssh.Server
|
||||||
|
|
||||||
@ -110,7 +111,7 @@ type Server struct {
|
|||||||
metrics *sshServerMetrics
|
metrics *sshServerMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) {
|
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
|
||||||
// Clients' should ignore the host key when connecting.
|
// Clients' should ignore the host key when connecting.
|
||||||
// The agent needs to authenticate with coderd to SSH,
|
// The agent needs to authenticate with coderd to SSH,
|
||||||
// so SSH authentication doesn't improve security.
|
// so SSH authentication doesn't improve security.
|
||||||
@ -153,6 +154,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
|||||||
|
|
||||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||||
s := &Server{
|
s := &Server{
|
||||||
|
Execer: execer,
|
||||||
listeners: make(map[net.Listener]struct{}),
|
listeners: make(map[net.Listener]struct{}),
|
||||||
fs: fs,
|
fs: fs,
|
||||||
conns: make(map[net.Conn]struct{}),
|
conns: make(map[net.Conn]struct{}),
|
||||||
@ -726,10 +728,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd, err := agentexec.PTYCommandContext(ctx, name, args...)
|
cmd := s.Execer.PTYCommandContext(ctx, name, args...)
|
||||||
if err != nil {
|
|
||||||
return nil, xerrors.Errorf("pty command context: %w", err)
|
|
||||||
}
|
|
||||||
cmd.Dir = s.config.WorkingDirectory()
|
cmd.Dir = s.config.WorkingDirectory()
|
||||||
|
|
||||||
// If the metadata directory doesn't exist, we run the command
|
// If the metadata directory doesn't exist, we run the command
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/pty"
|
"github.com/coder/coder/v2/pty"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
@ -35,7 +36,7 @@ func Test_sessionStart_orphan(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/agent/agentssh"
|
"github.com/coder/coder/v2/agent/agentssh"
|
||||||
"github.com/coder/coder/v2/pty/ptytest"
|
"github.com/coder/coder/v2/pty/ptytest"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
@ -36,7 +37,7 @@ func TestNewServer_ServeClient(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_ = s.Close()
|
_ = s.Close()
|
||||||
@ -108,7 +109,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
@ -159,7 +160,7 @@ func TestNewServer_Signal(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
@ -224,7 +225,7 @@ func TestNewServer_Signal(t *testing.T) {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
gossh "golang.org/x/crypto/ssh"
|
gossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/agent/agentssh"
|
"github.com/coder/coder/v2/agent/agentssh"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
@ -34,7 +35,7 @@ func TestServer_X11(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
logger := testutil.Logger(t)
|
logger := testutil.Logger(t)
|
||||||
fs := afero.NewOsFs()
|
fs := afero.NewOsFs()
|
||||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ type bufferedReconnectingPTY struct {
|
|||||||
|
|
||||||
// newBuffered starts the buffered pty. If the context ends the process will be
|
// newBuffered starts the buffered pty. If the context ends the process will be
|
||||||
// killed.
|
// killed.
|
||||||
func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *bufferedReconnectingPTY {
|
func newBuffered(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *bufferedReconnectingPTY {
|
||||||
rpty := &bufferedReconnectingPTY{
|
rpty := &bufferedReconnectingPTY{
|
||||||
activeConns: map[string]net.Conn{},
|
activeConns: map[string]net.Conn{},
|
||||||
command: cmd,
|
command: cmd,
|
||||||
@ -59,11 +59,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
|
|||||||
|
|
||||||
// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
|
// Add TERM then start the command with a pty. pty.Cmd duplicates Path as the
|
||||||
// first argument so remove it.
|
// first argument so remove it.
|
||||||
cmdWithEnv, err := agentexec.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
|
cmdWithEnv := execer.PTYCommandContext(ctx, cmd.Path, cmd.Args[1:]...)
|
||||||
if err != nil {
|
|
||||||
rpty.state.setState(StateDone, xerrors.Errorf("pty command context: %w", err))
|
|
||||||
return rpty
|
|
||||||
}
|
|
||||||
cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
cmdWithEnv.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||||
cmdWithEnv.Dir = rpty.command.Dir
|
cmdWithEnv.Dir = rpty.command.Dir
|
||||||
ptty, process, err := pty.Start(cmdWithEnv)
|
ptty, process, err := pty.Start(cmdWithEnv)
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/agent/agentexec"
|
||||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||||
"github.com/coder/coder/v2/pty"
|
"github.com/coder/coder/v2/pty"
|
||||||
)
|
)
|
||||||
@ -55,7 +56,7 @@ type ReconnectingPTY interface {
|
|||||||
// close itself (and all connections to it) if nothing is attached for the
|
// close itself (and all connections to it) if nothing is attached for the
|
||||||
// duration of the timeout, if the context ends, or the process exits (buffered
|
// duration of the timeout, if the context ends, or the process exits (buffered
|
||||||
// backend only).
|
// backend only).
|
||||||
func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) ReconnectingPTY {
|
func New(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) ReconnectingPTY {
|
||||||
if options.Timeout == 0 {
|
if options.Timeout == 0 {
|
||||||
options.Timeout = 5 * time.Minute
|
options.Timeout = 5 * time.Minute
|
||||||
}
|
}
|
||||||
@ -75,9 +76,9 @@ func New(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger
|
|||||||
|
|
||||||
switch backendType {
|
switch backendType {
|
||||||
case "screen":
|
case "screen":
|
||||||
return newScreen(ctx, cmd, options, logger)
|
return newScreen(ctx, logger, execer, cmd, options)
|
||||||
default:
|
default:
|
||||||
return newBuffered(ctx, cmd, options, logger)
|
return newBuffered(ctx, logger, execer, cmd, options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ import (
|
|||||||
|
|
||||||
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
|
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
|
||||||
type screenReconnectingPTY struct {
|
type screenReconnectingPTY struct {
|
||||||
|
execer agentexec.Execer
|
||||||
command *pty.Cmd
|
command *pty.Cmd
|
||||||
|
|
||||||
// id holds the id of the session for both creating and attaching. This will
|
// id holds the id of the session for both creating and attaching. This will
|
||||||
@ -59,8 +60,9 @@ type screenReconnectingPTY struct {
|
|||||||
// spawns the daemon with a hardcoded 24x80 size it is not a very good user
|
// spawns the daemon with a hardcoded 24x80 size it is not a very good user
|
||||||
// experience. Instead we will let the attach command spawn the daemon on its
|
// experience. Instead we will let the attach command spawn the daemon on its
|
||||||
// own which causes it to spawn with the specified size.
|
// own which causes it to spawn with the specified size.
|
||||||
func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.Logger) *screenReconnectingPTY {
|
func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
|
||||||
rpty := &screenReconnectingPTY{
|
rpty := &screenReconnectingPTY{
|
||||||
|
execer: execer,
|
||||||
command: cmd,
|
command: cmd,
|
||||||
metrics: options.Metrics,
|
metrics: options.Metrics,
|
||||||
state: newState(),
|
state: newState(),
|
||||||
@ -210,7 +212,7 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
|
|||||||
logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id))
|
logger.Debug(ctx, "spawning screen client", slog.F("screen_id", rpty.id))
|
||||||
|
|
||||||
// Wrap the command with screen and tie it to the connection's context.
|
// Wrap the command with screen and tie it to the connection's context.
|
||||||
cmd, err := agentexec.PTYCommandContext(ctx, "screen", append([]string{
|
cmd := rpty.execer.PTYCommandContext(ctx, "screen", append([]string{
|
||||||
// -S is for setting the session's name.
|
// -S is for setting the session's name.
|
||||||
"-S", rpty.id,
|
"-S", rpty.id,
|
||||||
// -U tells screen to use UTF-8 encoding.
|
// -U tells screen to use UTF-8 encoding.
|
||||||
@ -223,9 +225,6 @@ func (rpty *screenReconnectingPTY) doAttach(ctx context.Context, conn net.Conn,
|
|||||||
rpty.command.Path,
|
rpty.command.Path,
|
||||||
// pty.Cmd duplicates Path as the first argument so remove it.
|
// pty.Cmd duplicates Path as the first argument so remove it.
|
||||||
}, rpty.command.Args[1:]...)...)
|
}, rpty.command.Args[1:]...)...)
|
||||||
if err != nil {
|
|
||||||
return nil, nil, xerrors.Errorf("pty command context: %w", err)
|
|
||||||
}
|
|
||||||
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||||
cmd.Dir = rpty.command.Dir
|
cmd.Dir = rpty.command.Dir
|
||||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||||
@ -333,7 +332,7 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
|
|||||||
run := func() (bool, error) {
|
run := func() (bool, error) {
|
||||||
var stdout bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
cmd, err := agentexec.CommandContext(ctx, "screen",
|
cmd := rpty.execer.CommandContext(ctx, "screen",
|
||||||
// -x targets an attached session.
|
// -x targets an attached session.
|
||||||
"-x", rpty.id,
|
"-x", rpty.id,
|
||||||
// -c is the flag for the config file.
|
// -c is the flag for the config file.
|
||||||
@ -341,13 +340,10 @@ func (rpty *screenReconnectingPTY) sendCommand(ctx context.Context, command stri
|
|||||||
// -X runs a command in the matching session.
|
// -X runs a command in the matching session.
|
||||||
"-X", command,
|
"-X", command,
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return false, xerrors.Errorf("command context: %w", err)
|
|
||||||
}
|
|
||||||
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
cmd.Env = append(rpty.command.Env, "TERM=xterm-256color")
|
||||||
cmd.Dir = rpty.command.Dir
|
cmd.Dir = rpty.command.Dir
|
||||||
cmd.Stdout = &stdout
|
cmd.Stdout = &stdout
|
||||||
err = cmd.Run()
|
err := cmd.Run()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -165,10 +165,15 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
|
|||||||
return xerrors.Errorf("create command: %w", err)
|
return xerrors.Errorf("create command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rpty = New(ctx, cmd, &Options{
|
rpty = New(ctx,
|
||||||
Timeout: s.timeout,
|
logger.With(slog.F("message_id", msg.ID)),
|
||||||
Metrics: s.errorsTotal,
|
s.commandCreator.Execer,
|
||||||
}, logger.With(slog.F("message_id", msg.ID)))
|
cmd,
|
||||||
|
&Options{
|
||||||
|
Timeout: s.timeout,
|
||||||
|
Metrics: s.errorsTotal,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -309,6 +309,11 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
execer, err := agentexec.NewExecer()
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("create agent execer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
agnt := agent.New(agent.Options{
|
agnt := agent.New(agent.Options{
|
||||||
Client: client,
|
Client: client,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
@ -333,6 +338,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
|
|||||||
|
|
||||||
PrometheusRegistry: prometheusRegistry,
|
PrometheusRegistry: prometheusRegistry,
|
||||||
BlockFileTransfer: blockFileTransfer,
|
BlockFileTransfer: blockFileTransfer,
|
||||||
|
Execer: execer,
|
||||||
})
|
})
|
||||||
|
|
||||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||||
|
@ -503,7 +503,7 @@ func noExecInAgent(m dsl.Matcher) {
|
|||||||
!m.File().PkgPath.Matches("/agentexec") &&
|
!m.File().PkgPath.Matches("/agentexec") &&
|
||||||
!m.File().Name.Matches(`_test\.go$`),
|
!m.File().Name.Matches(`_test\.go$`),
|
||||||
).
|
).
|
||||||
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using agentexec.CommandContext instead.")
|
Report("The agent and its subpackages should not use exec.Command or exec.CommandContext directly. Consider using an agentexec.Execer instead.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// noPTYInAgent ensures that packages under agent/ don't use pty.Command or
|
// noPTYInAgent ensures that packages under agent/ don't use pty.Command or
|
||||||
@ -521,5 +521,5 @@ func noPTYInAgent(m dsl.Matcher) {
|
|||||||
!m.File().PkgPath.Matches(`/agentexec`) &&
|
!m.File().PkgPath.Matches(`/agentexec`) &&
|
||||||
!m.File().Name.Matches(`_test\.go$`),
|
!m.File().Name.Matches(`_test\.go$`),
|
||||||
).
|
).
|
||||||
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using agentexec.PTYCommandContext instead.")
|
Report("The agent and its subpackages should not use pty.Command or pty.CommandContext directly. Consider using an agentexec.Execer instead.")
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user