diff --git a/agent/agent.go b/agent/agent.go index 44f55fcedc..cdab56d935 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -15,6 +15,8 @@ import ( "os/exec" "os/user" "path/filepath" + "runtime" + "runtime/debug" "sort" "strconv" "strings" @@ -34,6 +36,7 @@ import ( "tailscale.com/types/netlogtype" "cdr.dev/slog" + "github.com/coder/coder/v2/agent/agentproc" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/reconnectingpty" "github.com/coder/coder/v2/buildinfo" @@ -51,6 +54,10 @@ const ( ProtocolDial = "dial" ) +// EnvProcPrioMgmt determines whether we attempt to manage +// process CPU and OOM Killer priority. +const EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT" + type Options struct { Filesystem afero.Fs LogDir string @@ -68,6 +75,11 @@ type Options struct { PrometheusRegistry *prometheus.Registry ReportMetadataInterval time.Duration ServiceBannerRefreshInterval time.Duration + Syscaller agentproc.Syscaller + // ModifiedProcesses is used for testing process priority management. + ModifiedProcesses chan []*agentproc.Process + // ProcessManagementTick is used for testing process priority management. + ProcessManagementTick <-chan time.Time } type Client interface { @@ -120,6 +132,10 @@ func New(options Options) Agent { prometheusRegistry = prometheus.NewRegistry() } + if options.Syscaller == nil { + options.Syscaller = agentproc.NewSyscaller() + } + ctx, cancelFunc := context.WithCancel(context.Background()) a := &agent{ tailnetListenPort: options.TailnetListenPort, @@ -143,6 +159,9 @@ func New(options Options) Agent { sshMaxTimeout: options.SSHMaxTimeout, subsystems: options.Subsystems, addresses: options.Addresses, + syscaller: options.Syscaller, + modifiedProcs: options.ModifiedProcesses, + processManagementTick: options.ProcessManagementTick, prometheusRegistry: prometheusRegistry, metrics: newAgentMetrics(prometheusRegistry), @@ -197,6 +216,12 @@ type agent struct { prometheusRegistry *prometheus.Registry metrics *agentMetrics + syscaller agentproc.Syscaller + + // modifiedProcs is used for testing process priority management. + modifiedProcs chan []*agentproc.Process + // processManagementTick is used for testing process priority management. + processManagementTick <-chan time.Time } func (a *agent) TailnetConn() *tailnet.Conn { @@ -225,6 +250,7 @@ func (a *agent) runLoop(ctx context.Context) { go a.reportLifecycleLoop(ctx) go a.reportMetadataLoop(ctx) go a.fetchServiceBannerLoop(ctx) + go a.manageProcessPriorityLoop(ctx) for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") @@ -1253,6 +1279,119 @@ func (a *agent) startReportingConnectionStats(ctx context.Context) { } } +var prioritizedProcs = []string{"coder agent"} + +func (a *agent) manageProcessPriorityLoop(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + a.logger.Critical(ctx, "recovered from panic", + slog.F("panic", r), + slog.F("stack", string(debug.Stack())), + ) + } + }() + + if val := a.envVars[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" { + a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj ", + slog.F("env_var", EnvProcPrioMgmt), + slog.F("value", val), + slog.F("goos", runtime.GOOS), + ) + return + } + + if a.processManagementTick == nil { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + a.processManagementTick = ticker.C + } + + for { + procs, err := a.manageProcessPriority(ctx) + if err != nil { + a.logger.Error(ctx, "manage process priority", + slog.Error(err), + ) + } + if a.modifiedProcs != nil { + a.modifiedProcs <- procs + } + + select { + case <-a.processManagementTick: + case <-ctx.Done(): + return + } + } +} + +func (a *agent) manageProcessPriority(ctx context.Context) ([]*agentproc.Process, error) { + const ( + niceness = 10 + ) + + procs, err := agentproc.List(a.filesystem, a.syscaller) + if err != nil { + return nil, xerrors.Errorf("list: %w", err) + } + + var ( + modProcs = []*agentproc.Process{} + logger slog.Logger + ) + + for _, proc := range procs { + logger = a.logger.With( + slog.F("cmd", proc.Cmd()), + slog.F("pid", proc.PID), + ) + + containsFn := func(e string) bool { + contains := strings.Contains(proc.Cmd(), e) + return contains + } + + // If the process is prioritized we should adjust + // it's oom_score_adj and avoid lowering its niceness. + if slices.ContainsFunc[[]string, string](prioritizedProcs, containsFn) { + continue + } + + score, err := proc.Niceness(a.syscaller) + if err != nil { + logger.Warn(ctx, "unable to get proc niceness", + slog.Error(err), + ) + continue + } + + // We only want processes that don't have a nice value set + // so we don't override user nice values. + // Getpriority actually returns priority for the nice value + // which is niceness + 20, so here 20 = a niceness of 0 (aka unset). + if score != 20 { + if score != niceness { + logger.Debug(ctx, "skipping process due to custom niceness", + slog.F("niceness", score), + ) + } + continue + } + + err = proc.SetNiceness(a.syscaller, niceness) + if err != nil { + logger.Warn(ctx, "unable to set proc niceness", + slog.F("niceness", niceness), + slog.Error(err), + ) + continue + } + + modProcs = append(modProcs, proc) + } + return modProcs, nil +} + // isClosed returns whether the API is closed or not. func (a *agent) isClosed() bool { select { diff --git a/agent/agent_test.go b/agent/agent_test.go index 126e0f4fa4..80fa7435c7 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -21,10 +21,12 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "testing" "time" scp "github.com/bramvdbogaerde/go-scp" + "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/pion/udp" "github.com/pkg/sftp" @@ -41,8 +43,11 @@ import ( "tailscale.com/tailcfg" "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agentproc" + "github.com/coder/coder/v2/agent/agentproc/agentproctest" "github.com/coder/coder/v2/agent/agentssh" "github.com/coder/coder/v2/agent/agenttest" "github.com/coder/coder/v2/coderd/httpapi" @@ -2395,6 +2400,173 @@ func TestAgent_Metrics_SSH(t *testing.T) { require.NoError(t, err) } +func TestAgent_ManageProcessPriority(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" { + t.Skip("Skipping non-linux environment") + } + + var ( + expectedProcs = map[int32]agentproc.Process{} + fs = afero.NewMemMapFs() + syscaller = agentproctest.NewMockSyscaller(gomock.NewController(t)) + ticker = make(chan time.Time) + modProcs = make(chan []*agentproc.Process) + logger = slog.Make(sloghuman.Sink(io.Discard)) + ) + + // Create some processes. + for i := 0; i < 4; i++ { + // Create a prioritized process. This process should + // have it's oom_score_adj set to -500 and its nice + // score should be untouched. + var proc agentproc.Process + if i == 0 { + proc = agentproctest.GenerateProcess(t, fs, + func(p *agentproc.Process) { + p.CmdLine = "./coder\x00agent\x00--no-reap" + p.PID = int32(i) + }, + ) + } else { + proc = agentproctest.GenerateProcess(t, fs, + func(p *agentproc.Process) { + // Make the cmd something similar to a prioritized + // process but differentiate the arguments. + p.CmdLine = "./coder\x00stat" + }, + ) + + syscaller.EXPECT().SetPriority(proc.PID, 10).Return(nil) + syscaller.EXPECT().GetPriority(proc.PID).Return(20, nil) + } + syscaller.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(nil) + + expectedProcs[proc.PID] = proc + } + + _, _, _, _, _ = setupAgent(t, agentsdk.Manifest{}, 0, func(c *agenttest.Client, o *agent.Options) { + o.Syscaller = syscaller + o.ModifiedProcesses = modProcs + o.EnvironmentVariables = map[string]string{agent.EnvProcPrioMgmt: "1"} + o.Filesystem = fs + o.Logger = logger + o.ProcessManagementTick = ticker + }) + actualProcs := <-modProcs + require.Len(t, actualProcs, len(expectedProcs)-1) + }) + + t.Run("IgnoreCustomNice", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" { + t.Skip("Skipping non-linux environment") + } + + var ( + expectedProcs = map[int32]agentproc.Process{} + fs = afero.NewMemMapFs() + ticker = make(chan time.Time) + syscaller = agentproctest.NewMockSyscaller(gomock.NewController(t)) + modProcs = make(chan []*agentproc.Process) + logger = slog.Make(sloghuman.Sink(io.Discard)) + ) + + // Create some processes. + for i := 0; i < 2; i++ { + proc := agentproctest.GenerateProcess(t, fs) + syscaller.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(nil) + + if i == 0 { + // Set a random nice score. This one should not be adjusted by + // our management loop. + syscaller.EXPECT().GetPriority(proc.PID).Return(25, nil) + } else { + syscaller.EXPECT().GetPriority(proc.PID).Return(20, nil) + syscaller.EXPECT().SetPriority(proc.PID, 10).Return(nil) + } + + expectedProcs[proc.PID] = proc + } + + _, _, _, _, _ = setupAgent(t, agentsdk.Manifest{}, 0, func(c *agenttest.Client, o *agent.Options) { + o.Syscaller = syscaller + o.ModifiedProcesses = modProcs + o.EnvironmentVariables = map[string]string{agent.EnvProcPrioMgmt: "1"} + o.Filesystem = fs + o.Logger = logger + o.ProcessManagementTick = ticker + }) + actualProcs := <-modProcs + // We should ignore the process with a custom nice score. + require.Len(t, actualProcs, 1) + }) + + t.Run("DisabledByDefault", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" { + t.Skip("Skipping non-linux environment") + } + + var ( + buf bytes.Buffer + wr = &syncWriter{ + w: &buf, + } + ) + log := slog.Make(sloghuman.Sink(wr)).Leveled(slog.LevelDebug) + + _, _, _, _, _ = setupAgent(t, agentsdk.Manifest{}, 0, func(c *agenttest.Client, o *agent.Options) { + o.Logger = log + }) + + require.Eventually(t, func() bool { + wr.mu.Lock() + defer wr.mu.Unlock() + return strings.Contains(buf.String(), "process priority not enabled") + }, testutil.WaitLong, testutil.IntervalFast) + }) + + t.Run("DisabledForNonLinux", func(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "linux" { + t.Skip("Skipping linux environment") + } + + var ( + buf bytes.Buffer + wr = &syncWriter{ + w: &buf, + } + ) + log := slog.Make(sloghuman.Sink(wr)).Leveled(slog.LevelDebug) + + _, _, _, _, _ = setupAgent(t, agentsdk.Manifest{}, 0, func(c *agenttest.Client, o *agent.Options) { + o.Logger = log + // Try to enable it so that we can assert that non-linux + // environments are truly disabled. + o.EnvironmentVariables = map[string]string{agent.EnvProcPrioMgmt: "1"} + }) + require.Eventually(t, func() bool { + wr.mu.Lock() + defer wr.mu.Unlock() + + return strings.Contains(buf.String(), "process priority not enabled") + }, testutil.WaitLong, testutil.IntervalFast) + }) +} + func verifyCollectedMetrics(t *testing.T, expected []agentsdk.AgentMetric, actual []*promgo.MetricFamily) bool { t.Helper() @@ -2416,3 +2588,14 @@ func verifyCollectedMetrics(t *testing.T, expected []agentsdk.AgentMetric, actua } return true } + +type syncWriter struct { + mu sync.Mutex + w io.Writer +} + +func (s *syncWriter) Write(p []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.w.Write(p) +} diff --git a/agent/agentproc/agentproctest/doc.go b/agent/agentproc/agentproctest/doc.go new file mode 100644 index 0000000000..5007b36268 --- /dev/null +++ b/agent/agentproc/agentproctest/doc.go @@ -0,0 +1,5 @@ +// Package agentproctest contains utility functions +// for testing process management in the agent. +package agentproctest + +//go:generate mockgen -destination ./syscallermock.go -package agentproctest github.com/coder/coder/v2/agent/agentproc Syscaller diff --git a/agent/agentproc/agentproctest/proc.go b/agent/agentproc/agentproctest/proc.go new file mode 100644 index 0000000000..c36e04ec1c --- /dev/null +++ b/agent/agentproc/agentproctest/proc.go @@ -0,0 +1,49 @@ +package agentproctest + +import ( + "fmt" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/agent/agentproc" + "github.com/coder/coder/v2/cryptorand" +) + +func GenerateProcess(t *testing.T, fs afero.Fs, muts ...func(*agentproc.Process)) agentproc.Process { + t.Helper() + + pid, err := cryptorand.Intn(1<<31 - 1) + require.NoError(t, err) + + arg1, err := cryptorand.String(5) + require.NoError(t, err) + + arg2, err := cryptorand.String(5) + require.NoError(t, err) + + arg3, err := cryptorand.String(5) + require.NoError(t, err) + + cmdline := fmt.Sprintf("%s\x00%s\x00%s", arg1, arg2, arg3) + + process := agentproc.Process{ + CmdLine: cmdline, + PID: int32(pid), + } + + for _, mut := range muts { + mut(&process) + } + + process.Dir = fmt.Sprintf("%s/%d", "/proc", process.PID) + + err = fs.MkdirAll(process.Dir, 0o555) + require.NoError(t, err) + + err = afero.WriteFile(fs, fmt.Sprintf("%s/cmdline", process.Dir), []byte(process.CmdLine), 0o444) + require.NoError(t, err) + + return process +} diff --git a/agent/agentproc/agentproctest/syscallermock.go b/agent/agentproc/agentproctest/syscallermock.go new file mode 100644 index 0000000000..8d9697bc55 --- /dev/null +++ b/agent/agentproc/agentproctest/syscallermock.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/v2/agent/agentproc (interfaces: Syscaller) + +// Package agentproctest is a generated GoMock package. +package agentproctest + +import ( + reflect "reflect" + syscall "syscall" + + gomock "github.com/golang/mock/gomock" +) + +// MockSyscaller is a mock of Syscaller interface. +type MockSyscaller struct { + ctrl *gomock.Controller + recorder *MockSyscallerMockRecorder +} + +// MockSyscallerMockRecorder is the mock recorder for MockSyscaller. +type MockSyscallerMockRecorder struct { + mock *MockSyscaller +} + +// NewMockSyscaller creates a new mock instance. +func NewMockSyscaller(ctrl *gomock.Controller) *MockSyscaller { + mock := &MockSyscaller{ctrl: ctrl} + mock.recorder = &MockSyscallerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSyscaller) EXPECT() *MockSyscallerMockRecorder { + return m.recorder +} + +// GetPriority mocks base method. +func (m *MockSyscaller) GetPriority(arg0 int32) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPriority", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPriority indicates an expected call of GetPriority. +func (mr *MockSyscallerMockRecorder) GetPriority(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPriority", reflect.TypeOf((*MockSyscaller)(nil).GetPriority), arg0) +} + +// Kill mocks base method. +func (m *MockSyscaller) Kill(arg0 int32, arg1 syscall.Signal) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Kill", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Kill indicates an expected call of Kill. +func (mr *MockSyscallerMockRecorder) Kill(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kill", reflect.TypeOf((*MockSyscaller)(nil).Kill), arg0, arg1) +} + +// SetPriority mocks base method. +func (m *MockSyscaller) SetPriority(arg0 int32, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetPriority", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetPriority indicates an expected call of SetPriority. +func (mr *MockSyscallerMockRecorder) SetPriority(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPriority", reflect.TypeOf((*MockSyscaller)(nil).SetPriority), arg0, arg1) +} diff --git a/agent/agentproc/doc.go b/agent/agentproc/doc.go new file mode 100644 index 0000000000..8b15c52c5f --- /dev/null +++ b/agent/agentproc/doc.go @@ -0,0 +1,3 @@ +// Package agentproc contains logic for interfacing with local +// processes running in the same context as the agent. +package agentproc diff --git a/agent/agentproc/proc_other.go b/agent/agentproc/proc_other.go new file mode 100644 index 0000000000..c0c4e2a25c --- /dev/null +++ b/agent/agentproc/proc_other.go @@ -0,0 +1,24 @@ +//go:build !linux +// +build !linux + +package agentproc + +import ( + "github.com/spf13/afero" +) + +func (p *Process) Niceness(sc Syscaller) (int, error) { + return 0, errUnimplemented +} + +func (p *Process) SetNiceness(sc Syscaller, score int) error { + return errUnimplemented +} + +func (p *Process) Cmd() string { + return "" +} + +func List(fs afero.Fs, syscaller Syscaller) ([]*Process, error) { + return nil, errUnimplemented +} diff --git a/agent/agentproc/proc_test.go b/agent/agentproc/proc_test.go new file mode 100644 index 0000000000..3799167950 --- /dev/null +++ b/agent/agentproc/proc_test.go @@ -0,0 +1,166 @@ +package agentproc_test + +import ( + "runtime" + "syscall" + "testing" + + "github.com/golang/mock/gomock" + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/agent/agentproc" + "github.com/coder/coder/v2/agent/agentproc/agentproctest" +) + +func TestList(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" { + t.Skipf("skipping non-linux environment") + } + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + var ( + fs = afero.NewMemMapFs() + sc = agentproctest.NewMockSyscaller(gomock.NewController(t)) + expectedProcs = make(map[int32]agentproc.Process) + ) + + for i := 0; i < 4; i++ { + proc := agentproctest.GenerateProcess(t, fs) + expectedProcs[proc.PID] = proc + + sc.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(nil) + } + + actualProcs, err := agentproc.List(fs, sc) + require.NoError(t, err) + require.Len(t, actualProcs, len(expectedProcs)) + for _, proc := range actualProcs { + expected, ok := expectedProcs[proc.PID] + require.True(t, ok) + require.Equal(t, expected.PID, proc.PID) + require.Equal(t, expected.CmdLine, proc.CmdLine) + require.Equal(t, expected.Dir, proc.Dir) + } + }) + + t.Run("FinishedProcess", func(t *testing.T) { + t.Parallel() + + var ( + fs = afero.NewMemMapFs() + sc = agentproctest.NewMockSyscaller(gomock.NewController(t)) + expectedProcs = make(map[int32]agentproc.Process) + ) + + for i := 0; i < 3; i++ { + proc := agentproctest.GenerateProcess(t, fs) + expectedProcs[proc.PID] = proc + + sc.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(nil) + } + + // Create a process that's already finished. We're not adding + // it to the map because it should be skipped over. + proc := agentproctest.GenerateProcess(t, fs) + sc.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(xerrors.New("os: process already finished")) + + actualProcs, err := agentproc.List(fs, sc) + require.NoError(t, err) + require.Len(t, actualProcs, len(expectedProcs)) + for _, proc := range actualProcs { + expected, ok := expectedProcs[proc.PID] + require.True(t, ok) + require.Equal(t, expected.PID, proc.PID) + require.Equal(t, expected.CmdLine, proc.CmdLine) + require.Equal(t, expected.Dir, proc.Dir) + } + }) + + t.Run("NoSuchProcess", func(t *testing.T) { + t.Parallel() + + var ( + fs = afero.NewMemMapFs() + sc = agentproctest.NewMockSyscaller(gomock.NewController(t)) + expectedProcs = make(map[int32]agentproc.Process) + ) + + for i := 0; i < 3; i++ { + proc := agentproctest.GenerateProcess(t, fs) + expectedProcs[proc.PID] = proc + + sc.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(nil) + } + + // Create a process that doesn't exist. We're not adding + // it to the map because it should be skipped over. + proc := agentproctest.GenerateProcess(t, fs) + sc.EXPECT(). + Kill(proc.PID, syscall.Signal(0)). + Return(syscall.ESRCH) + + actualProcs, err := agentproc.List(fs, sc) + require.NoError(t, err) + require.Len(t, actualProcs, len(expectedProcs)) + for _, proc := range actualProcs { + expected, ok := expectedProcs[proc.PID] + require.True(t, ok) + require.Equal(t, expected.PID, proc.PID) + require.Equal(t, expected.CmdLine, proc.CmdLine) + require.Equal(t, expected.Dir, proc.Dir) + } + }) +} + +// These tests are not very interesting but they provide some modicum of +// confidence. +func TestProcess(t *testing.T) { + t.Parallel() + + if runtime.GOOS != "linux" { + t.Skipf("skipping non-linux environment") + } + + t.Run("SetNiceness", func(t *testing.T) { + t.Parallel() + + var ( + sc = agentproctest.NewMockSyscaller(gomock.NewController(t)) + proc = &agentproc.Process{ + PID: 32, + } + score = 20 + ) + + sc.EXPECT().SetPriority(proc.PID, score).Return(nil) + err := proc.SetNiceness(sc, score) + require.NoError(t, err) + }) + + t.Run("Cmd", func(t *testing.T) { + t.Parallel() + + var ( + proc = &agentproc.Process{ + CmdLine: "helloworld\x00--arg1\x00--arg2", + } + expectedName = "helloworld --arg1 --arg2" + ) + + require.Equal(t, expectedName, proc.Cmd()) + }) +} diff --git a/agent/agentproc/proc_unix.go b/agent/agentproc/proc_unix.go new file mode 100644 index 0000000000..f52caed52e --- /dev/null +++ b/agent/agentproc/proc_unix.go @@ -0,0 +1,109 @@ +//go:build linux +// +build linux + +package agentproc + +import ( + "errors" + "path/filepath" + "strconv" + "strings" + "syscall" + + "github.com/spf13/afero" + "golang.org/x/xerrors" +) + +func List(fs afero.Fs, syscaller Syscaller) ([]*Process, error) { + d, err := fs.Open(defaultProcDir) + if err != nil { + return nil, xerrors.Errorf("open dir %q: %w", defaultProcDir, err) + } + defer d.Close() + + entries, err := d.Readdirnames(0) + if err != nil { + return nil, xerrors.Errorf("readdirnames: %w", err) + } + + processes := make([]*Process, 0, len(entries)) + for _, entry := range entries { + pid, err := strconv.ParseInt(entry, 10, 32) + if err != nil { + continue + } + + // Check that the process still exists. + exists, err := isProcessExist(syscaller, int32(pid)) + if err != nil { + return nil, xerrors.Errorf("check process exists: %w", err) + } + if !exists { + continue + } + + cmdline, err := afero.ReadFile(fs, filepath.Join(defaultProcDir, entry, "cmdline")) + if err != nil { + var errNo syscall.Errno + if xerrors.As(err, &errNo) && errNo == syscall.EPERM { + continue + } + return nil, xerrors.Errorf("read cmdline: %w", err) + } + processes = append(processes, &Process{ + PID: int32(pid), + CmdLine: string(cmdline), + Dir: filepath.Join(defaultProcDir, entry), + }) + } + + return processes, nil +} + +func isProcessExist(syscaller Syscaller, pid int32) (bool, error) { + err := syscaller.Kill(pid, syscall.Signal(0)) + if err == nil { + return true, nil + } + if err.Error() == "os: process already finished" { + return false, nil + } + + var errno syscall.Errno + if !errors.As(err, &errno) { + return false, err + } + + switch errno { + case syscall.ESRCH: + return false, nil + case syscall.EPERM: + return true, nil + } + + return false, xerrors.Errorf("kill: %w", err) +} + +func (p *Process) Niceness(sc Syscaller) (int, error) { + nice, err := sc.GetPriority(p.PID) + if err != nil { + return 0, xerrors.Errorf("get priority for %q: %w", p.CmdLine, err) + } + return nice, nil +} + +func (p *Process) SetNiceness(sc Syscaller, score int) error { + err := sc.SetPriority(p.PID, score) + if err != nil { + return xerrors.Errorf("set priority for %q: %w", p.CmdLine, err) + } + return nil +} + +func (p *Process) Cmd() string { + return strings.Join(p.cmdLine(), " ") +} + +func (p *Process) cmdLine() []string { + return strings.Split(p.CmdLine, "\x00") +} diff --git a/agent/agentproc/syscaller.go b/agent/agentproc/syscaller.go new file mode 100644 index 0000000000..1cd6640e36 --- /dev/null +++ b/agent/agentproc/syscaller.go @@ -0,0 +1,19 @@ +package agentproc + +import ( + "syscall" +) + +type Syscaller interface { + SetPriority(pid int32, priority int) error + GetPriority(pid int32) (int, error) + Kill(pid int32, sig syscall.Signal) error +} + +const defaultProcDir = "/proc" + +type Process struct { + Dir string + CmdLine string + PID int32 +} diff --git a/agent/agentproc/syscaller_other.go b/agent/agentproc/syscaller_other.go new file mode 100644 index 0000000000..114c553e43 --- /dev/null +++ b/agent/agentproc/syscaller_other.go @@ -0,0 +1,30 @@ +//go:build !linux +// +build !linux + +package agentproc + +import ( + "syscall" + + "golang.org/x/xerrors" +) + +func NewSyscaller() Syscaller { + return nopSyscaller{} +} + +var errUnimplemented = xerrors.New("unimplemented") + +type nopSyscaller struct{} + +func (nopSyscaller) SetPriority(pid int32, priority int) error { + return errUnimplemented +} + +func (nopSyscaller) GetPriority(pid int32) (int, error) { + return 0, errUnimplemented +} + +func (nopSyscaller) Kill(pid int32, sig syscall.Signal) error { + return errUnimplemented +} diff --git a/agent/agentproc/syscaller_unix.go b/agent/agentproc/syscaller_unix.go new file mode 100644 index 0000000000..e63e56b50f --- /dev/null +++ b/agent/agentproc/syscaller_unix.go @@ -0,0 +1,42 @@ +//go:build linux +// +build linux + +package agentproc + +import ( + "syscall" + + "golang.org/x/sys/unix" + "golang.org/x/xerrors" +) + +func NewSyscaller() Syscaller { + return UnixSyscaller{} +} + +type UnixSyscaller struct{} + +func (UnixSyscaller) SetPriority(pid int32, nice int) error { + err := unix.Setpriority(unix.PRIO_PROCESS, int(pid), nice) + if err != nil { + return xerrors.Errorf("set priority: %w", err) + } + return nil +} + +func (UnixSyscaller) GetPriority(pid int32) (int, error) { + nice, err := unix.Getpriority(0, int(pid)) + if err != nil { + return 0, xerrors.Errorf("get priority: %w", err) + } + return nice, nil +} + +func (UnixSyscaller) Kill(pid int32, sig syscall.Signal) error { + err := syscall.Kill(int(pid), sig) + if err != nil { + return xerrors.Errorf("kill: %w", err) + } + + return nil +} diff --git a/cli/agent.go b/cli/agent.go index 8b77c057ef..6a06f4d454 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -29,6 +29,7 @@ import ( "cdr.dev/slog/sloggers/slogjson" "cdr.dev/slog/sloggers/slogstackdriver" "github.com/coder/coder/v2/agent" + "github.com/coder/coder/v2/agent/agentproc" "github.com/coder/coder/v2/agent/reaper" "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/cli/clibase" @@ -267,6 +268,8 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { subsystems = append(subsystems, subsystem) } + procTicker := time.NewTicker(time.Second) + defer procTicker.Stop() agnt := agent.New(agent.Options{ Client: client, Logger: logger, @@ -284,13 +287,18 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd { return resp.SessionToken, nil }, EnvironmentVariables: map[string]string{ - "GIT_ASKPASS": executablePath, + "GIT_ASKPASS": executablePath, + agent.EnvProcPrioMgmt: os.Getenv(agent.EnvProcPrioMgmt), }, IgnorePorts: ignorePorts, SSHMaxTimeout: sshMaxTimeout, Subsystems: subsystems, PrometheusRegistry: prometheusRegistry, + Syscaller: agentproc.NewSyscaller(), + // Intentionally set this to nil. It's mainly used + // for testing. + ModifiedProcesses: nil, }) prometheusSrvClose := ServeHandler(ctx, logger, prometheusMetricsHandler(prometheusRegistry, logger), prometheusAddress, "prometheus")