mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
fix(agent): Allow signal propagation when running as PID 1 (#6141)
This commit is contained in:
committed by
GitHub
parent
af59e2bcfa
commit
6f3f7f2937
@ -1,6 +1,10 @@
|
|||||||
package reaper
|
package reaper
|
||||||
|
|
||||||
import "github.com/hashicorp/go-reap"
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-reap"
|
||||||
|
)
|
||||||
|
|
||||||
type Option func(o *options)
|
type Option func(o *options)
|
||||||
|
|
||||||
@ -22,7 +26,16 @@ func WithPIDCallback(ch reap.PidCh) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
// WithCatchSignals sets the signals that are caught and forwarded to the
|
||||||
ExecArgs []string
|
// child process. By default no signals are forwarded.
|
||||||
PIDs reap.PidCh
|
func WithCatchSignals(sigs ...os.Signal) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.CatchSignals = sigs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
ExecArgs []string
|
||||||
|
PIDs reap.PidCh
|
||||||
|
CatchSignals []os.Signal
|
||||||
}
|
}
|
||||||
|
@ -3,8 +3,11 @@
|
|||||||
package reaper_test
|
package reaper_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -15,9 +18,8 @@ import (
|
|||||||
"github.com/coder/coder/testutil"
|
"github.com/coder/coder/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//nolint:paralleltest // Non-parallel subtest.
|
||||||
func TestReap(t *testing.T) {
|
func TestReap(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// Don't run the reaper test in CI. It does weird
|
// Don't run the reaper test in CI. It does weird
|
||||||
// things like forkexecing which may have unintended
|
// things like forkexecing which may have unintended
|
||||||
// consequences in CI.
|
// consequences in CI.
|
||||||
@ -28,8 +30,9 @@ func TestReap(t *testing.T) {
|
|||||||
// OK checks that's the reaper is successfully reaping
|
// OK checks that's the reaper is successfully reaping
|
||||||
// exited processes and passing the PIDs through the shared
|
// exited processes and passing the PIDs through the shared
|
||||||
// channel.
|
// channel.
|
||||||
|
|
||||||
|
//nolint:paralleltest // Signal handling.
|
||||||
t.Run("OK", func(t *testing.T) {
|
t.Run("OK", func(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
pids := make(reap.PidCh, 1)
|
pids := make(reap.PidCh, 1)
|
||||||
err := reaper.ForkReap(
|
err := reaper.ForkReap(
|
||||||
reaper.WithPIDCallback(pids),
|
reaper.WithPIDCallback(pids),
|
||||||
@ -64,3 +67,39 @@ func TestReap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//nolint:paralleltest // Signal handling.
|
||||||
|
func TestReapInterrupt(t *testing.T) {
|
||||||
|
// Don't run the reaper test in CI. It does weird
|
||||||
|
// things like forkexecing which may have unintended
|
||||||
|
// consequences in CI.
|
||||||
|
if _, ok := os.LookupEnv("CI"); ok {
|
||||||
|
t.Skip("Detected CI, skipping reaper tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
errC := make(chan error, 1)
|
||||||
|
pids := make(reap.PidCh, 1)
|
||||||
|
|
||||||
|
// Use signals to notify when the child process is ready for the
|
||||||
|
// next step of our test.
|
||||||
|
usrSig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(usrSig, syscall.SIGUSR1, syscall.SIGUSR2)
|
||||||
|
defer signal.Stop(usrSig)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errC <- reaper.ForkReap(
|
||||||
|
reaper.WithPIDCallback(pids),
|
||||||
|
reaper.WithCatchSignals(os.Interrupt),
|
||||||
|
// Signal propagation does not extend to children of children, so
|
||||||
|
// we create a little bash script to ensure sleep is interrupted.
|
||||||
|
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Equal(t, <-usrSig, syscall.SIGUSR1)
|
||||||
|
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, <-usrSig, syscall.SIGUSR2)
|
||||||
|
|
||||||
|
require.NoError(t, <-errC)
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ package reaper
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/hashicorp/go-reap"
|
"github.com/hashicorp/go-reap"
|
||||||
@ -15,6 +16,24 @@ func IsInitProcess() bool {
|
|||||||
return os.Getpid() == 1
|
return os.Getpid() == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func catchSignals(pid int, sigs []os.Signal) {
|
||||||
|
if len(sigs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sc := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sc, sigs...)
|
||||||
|
defer signal.Stop(sc)
|
||||||
|
|
||||||
|
for {
|
||||||
|
s := <-sc
|
||||||
|
sig, ok := s.(syscall.Signal)
|
||||||
|
if ok {
|
||||||
|
_ = syscall.Kill(pid, sig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ForkReap spawns a goroutine that reaps children. In order to avoid
|
// ForkReap spawns a goroutine that reaps children. In order to avoid
|
||||||
// complications with spawning `exec.Commands` in the same process that
|
// complications with spawning `exec.Commands` in the same process that
|
||||||
// is reaping, we forkexec a child process. This prevents a race between
|
// is reaping, we forkexec a child process. This prevents a race between
|
||||||
@ -51,13 +70,17 @@ func ForkReap(opt ...Option) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//#nosec G204
|
//#nosec G204
|
||||||
pid, _ := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
|
pid, err := syscall.ForkExec(opts.ExecArgs[0], opts.ExecArgs, pattrs)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("fork exec: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go catchSignals(pid, opts.CatchSignals)
|
||||||
|
|
||||||
var wstatus syscall.WaitStatus
|
var wstatus syscall.WaitStatus
|
||||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||||
for xerrors.Is(err, syscall.EINTR) {
|
for xerrors.Is(err, syscall.EINTR) {
|
||||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,10 @@ func workspaceAgent() *cobra.Command {
|
|||||||
// Do not start a reaper on the child process. It's important
|
// Do not start a reaper on the child process. It's important
|
||||||
// to do this else we fork bomb ourselves.
|
// to do this else we fork bomb ourselves.
|
||||||
args := append(os.Args, "--no-reap")
|
args := append(os.Args, "--no-reap")
|
||||||
err := reaper.ForkReap(reaper.WithExecArgs(args...))
|
err := reaper.ForkReap(
|
||||||
|
reaper.WithExecArgs(args...),
|
||||||
|
reaper.WithCatchSignals(InterruptSignals...),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "failed to reap", slog.Error(err))
|
logger.Error(ctx, "failed to reap", slog.Error(err))
|
||||||
return xerrors.Errorf("fork reap: %w", err)
|
return xerrors.Errorf("fork reap: %w", err)
|
||||||
|
Reference in New Issue
Block a user