fix: Allow terraform provisions to be gracefully cancelled (#3526)

* fix: Allow terraform provisions to be gracefully cancelled

This change allows terraform commands to be gracefully cancelled on
Unix-like platforms by signaling interrupt on provision cancellation.

One implementation detail to note is that we do not necessarily kill a
running terraform command immediately even if the stream is closed. The
reason for this is to allow for graceful cancellation even in such an
event. Currently the timeout is set to 5 minutes by default.

Related: #2683

The above issue may be partially or fully fixed by this change.

* fix: Remove incorrect minimumTerraformVersion variable

* Allow init to return provision complete response
This commit is contained in:
Mathias Fredriksson
2022-08-18 17:03:55 +03:00
committed by GitHub
parent 6a0f8ae9cc
commit f1423450bd
6 changed files with 357 additions and 84 deletions

View File

@ -41,7 +41,7 @@ func (e executor) basicEnv() []string {
return env
}
func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOutWriter, stdErrWriter io.WriteCloser) (err error) {
func (e executor) execWriteOutput(ctx, killCtx context.Context, args, env []string, stdOutWriter, stdErrWriter io.WriteCloser) (err error) {
defer func() {
closeErr := stdOutWriter.Close()
if err == nil && closeErr != nil {
@ -52,8 +52,12 @@ func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOu
err = closeErr
}
}()
if ctx.Err() != nil {
return ctx.Err()
}
// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, args...)
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Env = env
@ -63,19 +67,36 @@ func (e executor) execWriteOutput(ctx context.Context, args, env []string, stdOu
cmd.Stdout = syncWriter{mut, stdOutWriter}
cmd.Stderr = syncWriter{mut, stdErrWriter}
return cmd.Run()
err = cmd.Start()
if err != nil {
return err
}
interruptCommandOnCancel(ctx, killCtx, cmd)
return cmd.Wait()
}
func (e executor) execParseJSON(ctx context.Context, args, env []string, v interface{}) error {
func (e executor) execParseJSON(ctx, killCtx context.Context, args, env []string, v interface{}) error {
if ctx.Err() != nil {
return ctx.Err()
}
// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, args...)
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Env = env
out := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
cmd.Stdout = out
cmd.Stderr = stdErr
err := cmd.Run()
err := cmd.Start()
if err != nil {
return err
}
interruptCommandOnCancel(ctx, killCtx, cmd)
err = cmd.Wait()
if err != nil {
errString, _ := io.ReadAll(stdErr)
return xerrors.Errorf("%s: %w", errString, err)
@ -95,11 +116,11 @@ func (e executor) checkMinVersion(ctx context.Context) error {
if err != nil {
return err
}
if !v.GreaterThanOrEqual(minimumTerraformVersion) {
if !v.GreaterThanOrEqual(minTerraformVersion) {
return xerrors.Errorf(
"terraform version %q is too old. required >= %q",
v.String(),
minimumTerraformVersion.String())
minTerraformVersion.String())
}
return nil
}
@ -109,6 +130,10 @@ func (e executor) version(ctx context.Context) (*version.Version, error) {
}
func versionFromBinaryPath(ctx context.Context, binaryPath string) (*version.Version, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
// #nosec
cmd := exec.CommandContext(ctx, binaryPath, "version", "-json")
out, err := cmd.Output()
@ -130,7 +155,7 @@ func versionFromBinaryPath(ctx context.Context, binaryPath string) (*version.Ver
return version.NewVersion(vj.Version)
}
func (e executor) init(ctx context.Context, logr logger) error {
func (e executor) init(ctx, killCtx context.Context, logr logger) error {
outWriter, doneOut := logWriter(logr, proto.LogLevel_DEBUG)
errWriter, doneErr := logWriter(logr, proto.LogLevel_ERROR)
defer func() {
@ -156,11 +181,11 @@ func (e executor) init(ctx context.Context, logr logger) error {
defer e.initMu.Unlock()
}
return e.execWriteOutput(ctx, args, e.basicEnv(), outWriter, errWriter)
return e.execWriteOutput(ctx, killCtx, args, e.basicEnv(), outWriter, errWriter)
}
// revive:disable-next-line:flag-parameter
func (e executor) plan(ctx context.Context, env, vars []string, logr logger, destroy bool) (*proto.Provision_Response, error) {
func (e executor) plan(ctx, killCtx context.Context, env, vars []string, logr logger, destroy bool) (*proto.Provision_Response, error) {
planfilePath := filepath.Join(e.workdir, "terraform.tfplan")
args := []string{
"plan",
@ -184,11 +209,11 @@ func (e executor) plan(ctx context.Context, env, vars []string, logr logger, des
<-doneErr
}()
err := e.execWriteOutput(ctx, args, env, outWriter, errWriter)
err := e.execWriteOutput(ctx, killCtx, args, env, outWriter, errWriter)
if err != nil {
return nil, xerrors.Errorf("terraform plan: %w", err)
}
resources, err := e.planResources(ctx, planfilePath)
resources, err := e.planResources(ctx, killCtx, planfilePath)
if err != nil {
return nil, err
}
@ -201,40 +226,52 @@ func (e executor) plan(ctx context.Context, env, vars []string, logr logger, des
}, nil
}
func (e executor) planResources(ctx context.Context, planfilePath string) ([]*proto.Resource, error) {
plan, err := e.showPlan(ctx, planfilePath)
func (e executor) planResources(ctx, killCtx context.Context, planfilePath string) ([]*proto.Resource, error) {
plan, err := e.showPlan(ctx, killCtx, planfilePath)
if err != nil {
return nil, xerrors.Errorf("show terraform plan file: %w", err)
}
rawGraph, err := e.graph(ctx)
rawGraph, err := e.graph(ctx, killCtx)
if err != nil {
return nil, xerrors.Errorf("graph: %w", err)
}
return ConvertResources(plan.PlannedValues.RootModule, rawGraph)
}
func (e executor) showPlan(ctx context.Context, planfilePath string) (*tfjson.Plan, error) {
func (e executor) showPlan(ctx, killCtx context.Context, planfilePath string) (*tfjson.Plan, error) {
args := []string{"show", "-json", "-no-color", planfilePath}
p := new(tfjson.Plan)
err := e.execParseJSON(ctx, args, e.basicEnv(), p)
err := e.execParseJSON(ctx, killCtx, args, e.basicEnv(), p)
return p, err
}
func (e executor) graph(ctx context.Context) (string, error) {
// #nosec
cmd := exec.CommandContext(ctx, e.binaryPath, "graph")
func (e executor) graph(ctx, killCtx context.Context) (string, error) {
if ctx.Err() != nil {
return "", ctx.Err()
}
var out bytes.Buffer
cmd := exec.CommandContext(killCtx, e.binaryPath, "graph") // #nosec
cmd.Stdout = &out
cmd.Dir = e.workdir
cmd.Env = e.basicEnv()
out, err := cmd.Output()
err := cmd.Start()
if err != nil {
return "", err
}
interruptCommandOnCancel(ctx, killCtx, cmd)
err = cmd.Wait()
if err != nil {
return "", xerrors.Errorf("graph: %w", err)
}
return string(out), nil
return out.String(), nil
}
// revive:disable-next-line:flag-parameter
func (e executor) apply(ctx context.Context, env, vars []string, logr logger, destroy bool,
func (e executor) apply(ctx, killCtx context.Context, env, vars []string, logr logger, destroy bool,
) (*proto.Provision_Response, error) {
args := []string{
"apply",
@ -258,11 +295,11 @@ func (e executor) apply(ctx context.Context, env, vars []string, logr logger, de
<-doneErr
}()
err := e.execWriteOutput(ctx, args, env, outWriter, errWriter)
err := e.execWriteOutput(ctx, killCtx, args, env, outWriter, errWriter)
if err != nil {
return nil, xerrors.Errorf("terraform apply: %w", err)
}
resources, err := e.stateResources(ctx)
resources, err := e.stateResources(ctx, killCtx)
if err != nil {
return nil, err
}
@ -281,12 +318,12 @@ func (e executor) apply(ctx context.Context, env, vars []string, logr logger, de
}, nil
}
func (e executor) stateResources(ctx context.Context) ([]*proto.Resource, error) {
state, err := e.state(ctx)
func (e executor) stateResources(ctx, killCtx context.Context) ([]*proto.Resource, error) {
state, err := e.state(ctx, killCtx)
if err != nil {
return nil, err
}
rawGraph, err := e.graph(ctx)
rawGraph, err := e.graph(ctx, killCtx)
if err != nil {
return nil, xerrors.Errorf("get terraform graph: %w", err)
}
@ -300,16 +337,33 @@ func (e executor) stateResources(ctx context.Context) ([]*proto.Resource, error)
return resources, nil
}
func (e executor) state(ctx context.Context) (*tfjson.State, error) {
func (e executor) state(ctx, killCtx context.Context) (*tfjson.State, error) {
args := []string{"show", "-json"}
state := &tfjson.State{}
err := e.execParseJSON(ctx, args, e.basicEnv(), state)
err := e.execParseJSON(ctx, killCtx, args, e.basicEnv(), state)
if err != nil {
return nil, xerrors.Errorf("terraform show state: %w", err)
}
return state, nil
}
// interruptCommandOnCancel spawns a goroutine that blocks until either ctx
// or killCtx is done.
//
// If the select picks ctx, the command is asked to stop: os.Interrupt is
// sent to the process, except on Windows where the process is killed
// instead. If the select picks killCtx, the goroutine returns without
// touching the process.
//
// cmd.Process is dereferenced, so cmd must already have been started.
// Errors from signalling are deliberately discarded.
func interruptCommandOnCancel(ctx, killCtx context.Context, cmd *exec.Cmd) {
	requestStop := func() {
		if runtime.GOOS == "windows" {
			// Interrupts aren't supported by Windows.
			_ = cmd.Process.Kill()
			return
		}
		_ = cmd.Process.Signal(os.Interrupt)
	}

	go func() {
		select {
		case <-killCtx.Done():
			// Nothing to signal here.
		case <-ctx.Done():
			requestStop()
		}
	}()
}
type logger interface {
Log(*proto.Log) error
}
@ -381,9 +435,6 @@ func provisionReadAndLog(logr logger, reader io.Reader, done chan<- any) {
// If the diagnostic is provided, let's provide a bit more info!
logLevel = convertTerraformLogLevel(log.Diagnostic.Severity, logr)
if err != nil {
continue
}
err = logr.Log(&proto.Log{Level: logLevel, Output: log.Diagnostic.Detail})
if err != nil {
// Not much we can do. We can't log because logging is itself breaking!

View File

@ -16,7 +16,7 @@ import (
func TestParse(t *testing.T) {
t.Parallel()
ctx, api := setupProvisioner(t)
ctx, api := setupProvisioner(t, nil)
testCases := []struct {
Name string
@ -171,7 +171,7 @@ func TestParse(t *testing.T) {
// Write all files to the temporary test directory.
directory := t.TempDir()
for path, content := range testCase.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0o600)
require.NoError(t, err)
}

View File

@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"strings"
"time"
"golang.org/x/xerrors"
@ -14,11 +15,7 @@ import (
)
// Provision executes `terraform apply` or `terraform plan` for dry runs.
func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
logr := streamLogger{stream: stream}
shutdown, shutdownFunc := context.WithCancel(stream.Context())
defer shutdownFunc()
func (s *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
request, err := stream.Recv()
if err != nil {
return err
@ -30,6 +27,33 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
if request.GetStart() == nil {
return nil
}
// Create a context for graceful cancellation bound to the stream
// context. This ensures that we will perform graceful cancellation
// even on connection loss.
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
// Create a separate context for forceful cancellation not tied to
// the stream so that we can control when to terminate the process.
killCtx, kill := context.WithCancel(context.Background())
defer kill()
// Ensure processes are eventually cleaned up on graceful
// cancellation or disconnect.
go func() {
<-stream.Context().Done()
// TODO(mafredri): We should track this provision request as
// part of graceful server shutdown procedure. Waiting on a
// process here should delay provisioner/coder shutdown.
select {
case <-time.After(s.exitTimeout):
kill()
case <-killCtx.Done():
}
}()
go func() {
for {
request, err := stream.Recv()
@ -37,29 +61,28 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
return
}
if request.GetCancel() == nil {
// This is only to process cancels!
// We only process cancellation requests here.
continue
}
shutdownFunc()
cancel()
return
}
}()
logr := streamLogger{stream: stream}
start := request.GetStart()
if err != nil {
return xerrors.Errorf("create new terraform executor: %w", err)
}
e := t.executor(start.Directory)
if err := e.checkMinVersion(stream.Context()); err != nil {
e := s.executor(start.Directory)
if err = e.checkMinVersion(ctx); err != nil {
return err
}
if err := logTerraformEnvVars(logr); err != nil {
if err = logTerraformEnvVars(logr); err != nil {
return err
}
statefilePath := filepath.Join(start.Directory, "terraform.tfstate")
if len(start.State) > 0 {
err := os.WriteFile(statefilePath, start.State, 0600)
err = os.WriteFile(statefilePath, start.State, 0o600)
if err != nil {
return xerrors.Errorf("write statefile %q: %w", statefilePath, err)
}
@ -87,12 +110,21 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
})
}
t.logger.Debug(shutdown, "running initialization")
err = e.init(stream.Context(), logr)
s.logger.Debug(ctx, "running initialization")
err = e.init(ctx, killCtx, logr)
if err != nil {
if ctx.Err() != nil {
return stream.Send(&proto.Provision_Response{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Error: err.Error(),
},
},
})
}
return xerrors.Errorf("initialize terraform: %w", err)
}
t.logger.Debug(shutdown, "ran initialization")
s.logger.Debug(ctx, "ran initialization")
env, err := provisionEnv(start)
if err != nil {
@ -104,15 +136,15 @@ func (t *server) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
}
var resp *proto.Provision_Response
if start.DryRun {
resp, err = e.plan(shutdown, env, vars, logr,
resp, err = e.plan(ctx, killCtx, env, vars, logr,
start.Metadata.WorkspaceTransition == proto.WorkspaceTransition_DESTROY)
} else {
resp, err = e.apply(shutdown, env, vars, logr,
resp, err = e.apply(ctx, killCtx, env, vars, logr,
start.Metadata.WorkspaceTransition == proto.WorkspaceTransition_DESTROY)
}
if err != nil {
if start.DryRun {
if shutdown.Err() != nil {
if ctx.Err() != nil {
return stream.Send(&proto.Provision_Response{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{

View File

@ -5,11 +5,14 @@ package terraform_test
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -22,7 +25,15 @@ import (
"github.com/coder/coder/provisionersdk/proto"
)
func setupProvisioner(t *testing.T) (context.Context, proto.DRPCProvisionerClient) {
// provisionerServeOptions holds the optional settings that setupProvisioner
// forwards to the provisioner's serve options.
type provisionerServeOptions struct {
// binaryPath is passed through as the BinaryPath serve option.
binaryPath string
// exitTimeout is passed through as the ExitTimeout serve option.
exitTimeout time.Duration
}
func setupProvisioner(t *testing.T, opts *provisionerServeOptions) (context.Context, proto.DRPCProvisionerClient) {
if opts == nil {
opts = &provisionerServeOptions{}
}
cachePath := t.TempDir()
client, server := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
@ -39,18 +50,125 @@ func setupProvisioner(t *testing.T) (context.Context, proto.DRPCProvisionerClien
ServeOptions: &provisionersdk.ServeOptions{
Listener: server,
},
BinaryPath: opts.binaryPath,
CachePath: cachePath,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
ExitTimeout: opts.exitTimeout,
})
}()
api := proto.NewDRPCProvisionerClient(provisionersdk.Conn(client))
return ctx, api
}
// TestProvision_Cancel checks that sending a cancel request on the provision
// stream interrupts the running (fake) terraform command and that the
// provision completes with the command's exit error.
//
// Each case writes a small wrapper script named "terraform" that execs the
// fake terraform script in the given mode, waits for the expected start-up
// log lines, sends the cancel request and then compares the log lines
// received up to the completion message.
func TestProvision_Cancel(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("This test uses interrupts and is not supported on Windows")
	}

	t.Parallel()

	cwd, err := os.Getwd()
	require.NoError(t, err)
	fakeBin := filepath.Join(cwd, "testdata", "bin", "terraform_fake_cancel.sh")

	tests := []struct {
		name          string
		mode          string
		startSequence []string
		wantLog       []string
	}{
		{
			name:          "Cancel init",
			mode:          "init",
			startSequence: []string{"init_start"},
			wantLog:       []string{"interrupt", "exit"},
		},
		{
			name:          "Cancel apply",
			mode:          "apply",
			startSequence: []string{"init", "apply_start"},
			wantLog:       []string{"interrupt", "exit"},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			dir := t.TempDir()
			binPath := filepath.Join(dir, "terraform")

			// Example: exec /path/to/terraform_fake_cancel.sh 1.2.1 apply "$@"
			content := fmt.Sprintf("#!/bin/sh\nexec %q %s %s \"$@\"\n", fakeBin, terraform.TerraformVersion.String(), tt.mode)
			// err is declared locally on purpose: the subtests run in
			// parallel, so assigning to the parent's err would be a data
			// race between them.
			err := os.WriteFile(binPath, []byte(content), 0o755) //#nosec
			require.NoError(t, err)

			ctx, api := setupProvisioner(t, &provisionerServeOptions{
				binaryPath:  binPath,
				exitTimeout: time.Nanosecond,
			})

			response, err := api.Provision(ctx)
			require.NoError(t, err)
			err = response.Send(&proto.Provision_Request{
				Type: &proto.Provision_Request_Start{
					Start: &proto.Provision_Start{
						Directory: dir,
						DryRun:    false,
						ParameterValues: []*proto.ParameterValue{{
							DestinationScheme: proto.ParameterDestination_PROVISIONER_VARIABLE,
							Name:              "A",
							Value:             "example",
						}},
						Metadata: &proto.Provision_Metadata{},
					},
				},
			})
			require.NoError(t, err)

			// Wait for each expected start-up log line, skipping any
			// non-log messages that arrive in between.
			for _, line := range tt.startSequence {
				for {
					msg, err := response.Recv()
					require.NoError(t, err)

					t.Log(msg.Type)

					if log := msg.GetLog(); log != nil {
						require.Equal(t, line, log.Output)
						break
					}
				}
			}

			err = response.Send(&proto.Provision_Request{
				Type: &proto.Provision_Request_Cancel{
					Cancel: &proto.Provision_Cancel{},
				},
			})
			require.NoError(t, err)

			// Collect log output until the completion message arrives.
			var gotLog []string
			for {
				msg, err := response.Recv()
				require.NoError(t, err)

				if log := msg.GetLog(); log != nil {
					gotLog = append(gotLog, log.Output)
				}
				if c := msg.GetComplete(); c != nil {
					require.Contains(t, c.Error, "exit status 1")
					break
				}
			}

			require.Equal(t, tt.wantLog, gotLog)
		})
	}
}
func TestProvision(t *testing.T) {
t.Parallel()
ctx, api := setupProvisioner(t)
ctx, api := setupProvisioner(t, nil)
testCases := []struct {
Name string
@ -209,7 +327,7 @@ func TestProvision(t *testing.T) {
directory := t.TempDir()
for path, content := range testCase.Files {
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600)
err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0o600)
require.NoError(t, err)
}
@ -302,11 +420,11 @@ func TestProvision_ExtraEnv(t *testing.T) {
t.Setenv("TF_LOG", "INFO")
t.Setenv("TF_SUPERSECRET", secretValue)
ctx, api := setupProvisioner(t)
ctx, api := setupProvisioner(t, nil)
directory := t.TempDir()
path := filepath.Join(directory, "main.tf")
err := os.WriteFile(path, []byte(`resource "null_resource" "A" {}`), 0600)
err := os.WriteFile(path, []byte(`resource "null_resource" "A" {}`), 0o600)
require.NoError(t, err)
request := &proto.Provision_Request{

View File

@ -4,6 +4,7 @@ import (
"context"
"path/filepath"
"sync"
"time"
"github.com/cli/safeexec"
"github.com/hashicorp/go-version"
@ -15,26 +16,20 @@ import (
"github.com/coder/coder/provisionersdk"
)
// This is the exact version of Terraform used internally
// when Terraform is missing on the system.
var terraformVersion = version.Must(version.NewVersion("1.2.1"))
var minTerraformVersion = version.Must(version.NewVersion("1.1.0"))
var maxTerraformVersion = version.Must(version.NewVersion("1.2.1"))
var (
// The minimum version of Terraform supported by the provisioner.
// Validation came out in 0.13.0, which was released August 10th, 2020.
// https://www.hashicorp.com/blog/announcing-hashicorp-terraform-0-13
minimumTerraformVersion = func() *version.Version {
v, err := version.NewSemver("0.13.0")
if err != nil {
panic(err)
}
return v
}()
// TerraformVersion is the version of Terraform used internally
// when Terraform is not available on the system.
TerraformVersion = version.Must(version.NewVersion("1.2.1"))
minTerraformVersion = version.Must(version.NewVersion("1.1.0"))
maxTerraformVersion = version.Must(version.NewVersion("1.2.1"))
terraformMinorVersionMismatch = xerrors.New("Terraform binary minor version mismatch.")
)
var terraformMinorVersionMismatch = xerrors.New("Terraform binary minor version mismatch.")
const (
// defaultExitTimeout is the value Serve assigns to ExitTimeout when the
// option is left at zero.
defaultExitTimeout = 5 * time.Minute
)
type ServeOptions struct {
*provisionersdk.ServeOptions
@ -44,6 +39,17 @@ type ServeOptions struct {
BinaryPath string
CachePath string
Logger slog.Logger
// ExitTimeout defines how long we will wait for a running Terraform
// command to exit (cleanly) if the provision was stopped. This only
// happens when the command is still running after the provision
// stream is closed. If the provision is canceled via RPC, this
// timeout will not be used.
//
// This is a no-op on Windows where the process can't be interrupted.
//
// Default value: 5 minutes.
ExitTimeout time.Duration
}
func absoluteBinaryPath(ctx context.Context) (string, error) {
@ -90,7 +96,7 @@ func Serve(ctx context.Context, options *ServeOptions) error {
installer := &releases.ExactVersion{
InstallDir: options.CachePath,
Product: product.Terraform,
Version: terraformVersion,
Version: TerraformVersion,
}
execPath, err := installer.Install(ctx)
@ -102,10 +108,14 @@ func Serve(ctx context.Context, options *ServeOptions) error {
options.BinaryPath = absoluteBinary
}
}
if options.ExitTimeout == 0 {
options.ExitTimeout = defaultExitTimeout
}
return provisionersdk.Serve(ctx, &server{
binaryPath: options.BinaryPath,
cachePath: options.CachePath,
logger: options.Logger,
exitTimeout: options.ExitTimeout,
}, options.ServeOptions)
}
@ -117,6 +127,8 @@ type server struct {
binaryPath string
cachePath string
logger slog.Logger
exitTimeout time.Duration
}
func (s *server) executor(workdir string) executor {

View File

@ -0,0 +1,60 @@
#!/bin/sh

# Fake terraform executable; it appears to be the script that
# TestProvision_Cancel wraps as "terraform" (confirm against the test).
#
# Usage: <script> VERSION MODE [terraform arguments...]
#   VERSION  value reported as "terraform_version" by the "version" command.
#   MODE     selects which command blocks until it is signalled:
#            "init" makes "init" block; "apply" makes "init" return
#            immediately. The "apply" command always blocks.

VERSION=$1
MODE=$2
shift 2

# json_print writes its arguments as one error-level JSON log line.
json_print() {
	echo "{\"@level\":\"error\",\"@message\":\"$*\"}"
}

case "$1" in
version)
	# The here-document body and its terminator stay at column zero so the
	# document terminates correctly whether this file is indented with
	# tabs or spaces.
	cat <<-EOF
{
"terraform_version": "${VERSION}",
"platform": "linux_amd64",
"provider_selections": {},
"terraform_outdated": false
}
EOF
	exit 0
	;;
init)
	case "$MODE" in
	apply)
		echo "init"
		;;
	init)
		# Sleep in the background so that "wait" can be interrupted by a
		# signal and the traps below get a chance to run.
		sleep 10 &
		sleep_pid=$!

		trap 'echo exit; kill -9 $sleep_pid 2>/dev/null' EXIT
		trap 'echo interrupt; exit 1' INT
		trap 'echo terminate; exit 2' TERM

		echo init_start
		wait
		echo init_end
		;;
	esac
	;;
apply)
	# Same blocking pattern as "init", but with JSON log output.
	sleep 10 &
	sleep_pid=$!

	trap 'json_print exit; kill -9 $sleep_pid 2>/dev/null' EXIT
	trap 'json_print interrupt; exit 1' INT
	trap 'json_print terminate; exit 2' TERM

	json_print apply_start
	wait
	json_print apply_end
	;;
plan)
	echo "plan not supported"
	exit 1
	;;
esac

exit 0