fix: Improve coder server shutdown procedure (#3246)

* fix: Improve `coder server` shutdown procedure

This commit improves the `coder server` shutdown procedure so that all
shutdown triggers result in a graceful exit without skipping any
steps.

We also improve cancellation and shutdown of services by ensuring
resources are cleaned up at the end.

Notable changes:
- We wrap `cmd.Context()` to allow us to control cancellation better
- We attempt graceful shutdown of the http server (`server.Shutdown`)
  because it's less abrupt (compared to `shutdownConns`)
- All exit paths share the same shutdown procedure (except for early
  exit)
- `provisionerd`s are now shut down concurrently instead of one at a
  time; they also now get a new context for shutdown because
  `cmd.Context()` may be cancelled
- Resources created by `newProvisionerDaemon` are cleaned up
- Lifecycle `Executor` exits its goroutine on context cancellation

Fixes #3245
This commit is contained in:
Mathias Fredriksson
2022-07-27 18:21:21 +03:00
committed by GitHub
parent bb05b1f749
commit d27076cac7
5 changed files with 260 additions and 128 deletions

View File

@ -20,6 +20,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/coreos/go-systemd/daemon" "github.com/coreos/go-systemd/daemon"
@ -111,26 +112,34 @@ func server() *cobra.Command {
logger = logger.Leveled(slog.LevelDebug) logger = logger.Leveled(slog.LevelDebug)
} }
// Main command context for managing cancellation
// of running services.
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
// Clean up idle connections at the end, e.g.
// embedded-postgres can leave an idle connection
// which is caught by goleaks.
defer http.DefaultClient.CloseIdleConnections()
var ( var (
tracerProvider *sdktrace.TracerProvider tracerProvider *sdktrace.TracerProvider
err error err error
sqlDriver = "postgres" sqlDriver = "postgres"
) )
if trace { if trace {
tracerProvider, err = tracing.TracerProvider(cmd.Context(), "coderd") tracerProvider, err = tracing.TracerProvider(ctx, "coderd")
if err != nil { if err != nil {
logger.Warn(cmd.Context(), "failed to start telemetry exporter", slog.Error(err)) logger.Warn(ctx, "failed to start telemetry exporter", slog.Error(err))
} else { } else {
// allow time for traces to flush even if command context is canceled
defer func() { defer func() {
// allow time for traces to flush even if command context is canceled _ = shutdownWithTimeout(tracerProvider, 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = tracerProvider.Shutdown(ctx)
}() }()
d, err := tracing.PostgresDriver(tracerProvider, "coderd.database") d, err := tracing.PostgresDriver(tracerProvider, "coderd.database")
if err != nil { if err != nil {
logger.Warn(cmd.Context(), "failed to start postgres tracing driver", slog.Error(err)) logger.Warn(ctx, "failed to start postgres tracing driver", slog.Error(err))
} else { } else {
sqlDriver = d sqlDriver = d
} }
@ -143,14 +152,16 @@ func server() *cobra.Command {
if !inMemoryDatabase && postgresURL == "" { if !inMemoryDatabase && postgresURL == "" {
var closeFunc func() error var closeFunc func() error
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath()) cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
postgresURL, closeFunc, err = startBuiltinPostgres(cmd.Context(), config, logger) postgresURL, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
if err != nil { if err != nil {
return err return err
} }
builtinPostgres = true builtinPostgres = true
defer func() { defer func() {
cmd.Printf("Stopping built-in PostgreSQL...\n")
// Gracefully shut PostgreSQL down! // Gracefully shut PostgreSQL down!
_ = closeFunc() _ = closeFunc()
cmd.Printf("Stopped built-in PostgreSQL\n")
}() }()
} }
@ -189,9 +200,9 @@ func server() *cobra.Command {
} }
var ( var (
ctxTunnel, closeTunnel = context.WithCancel(cmd.Context()) ctxTunnel, closeTunnel = context.WithCancel(ctx)
devTunnel = (*devtunnel.Tunnel)(nil) devTunnel *devtunnel.Tunnel
devTunnelErrChan = make(<-chan error, 1) devTunnelErr <-chan error
) )
defer closeTunnel() defer closeTunnel()
@ -199,7 +210,7 @@ func server() *cobra.Command {
// needs to be changed to use the tunnel. // needs to be changed to use the tunnel.
if tunnel { if tunnel {
cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n") cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n")
devTunnel, devTunnelErrChan, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel")) devTunnel, devTunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
if err != nil { if err != nil {
return xerrors.Errorf("create tunnel: %w", err) return xerrors.Errorf("create tunnel: %w", err)
} }
@ -207,7 +218,7 @@ func server() *cobra.Command {
} }
// Warn the user if the access URL appears to be a loopback address. // Warn the user if the access URL appears to be a loopback address.
isLocal, err := isLocalURL(cmd.Context(), accessURL) isLocal, err := isLocalURL(ctx, accessURL)
if isLocal || err != nil { if isLocal || err != nil {
reason := "could not be resolved" reason := "could not be resolved"
if isLocal { if isLocal {
@ -224,7 +235,7 @@ func server() *cobra.Command {
} }
// Used for zero-trust instance identity with Google Cloud. // Used for zero-trust instance identity with Google Cloud.
googleTokenValidator, err := idtoken.NewValidator(cmd.Context(), option.WithoutAuthentication()) googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
if err != nil { if err != nil {
return err return err
} }
@ -241,6 +252,7 @@ func server() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("create turn server: %w", err) return xerrors.Errorf("create turn server: %w", err)
} }
defer turnServer.Close()
iceServers := make([]webrtc.ICEServer, 0) iceServers := make([]webrtc.ICEServer, 0)
for _, stunServer := range stunServers { for _, stunServer := range stunServers {
@ -278,6 +290,8 @@ func server() *cobra.Command {
if err != nil { if err != nil {
return xerrors.Errorf("dial postgres: %w", err) return xerrors.Errorf("dial postgres: %w", err)
} }
defer sqlDB.Close()
err = sqlDB.Ping() err = sqlDB.Ping()
if err != nil { if err != nil {
return xerrors.Errorf("ping postgres: %w", err) return xerrors.Errorf("ping postgres: %w", err)
@ -287,13 +301,14 @@ func server() *cobra.Command {
return xerrors.Errorf("migrate up: %w", err) return xerrors.Errorf("migrate up: %w", err)
} }
options.Database = database.New(sqlDB) options.Database = database.New(sqlDB)
options.Pubsub, err = database.NewPubsub(cmd.Context(), sqlDB, postgresURL) options.Pubsub, err = database.NewPubsub(ctx, sqlDB, postgresURL)
if err != nil { if err != nil {
return xerrors.Errorf("create pubsub: %w", err) return xerrors.Errorf("create pubsub: %w", err)
} }
defer options.Pubsub.Close()
} }
deploymentID, err := options.Database.GetDeploymentID(cmd.Context()) deploymentID, err := options.Database.GetDeploymentID(ctx)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err = nil err = nil
} }
@ -302,7 +317,7 @@ func server() *cobra.Command {
} }
if deploymentID == "" { if deploymentID == "" {
deploymentID = uuid.NewString() deploymentID = uuid.NewString()
err = options.Database.InsertDeploymentID(cmd.Context(), deploymentID) err = options.Database.InsertDeploymentID(ctx, deploymentID)
if err != nil { if err != nil {
return xerrors.Errorf("set deployment id: %w", err) return xerrors.Errorf("set deployment id: %w", err)
} }
@ -336,6 +351,8 @@ func server() *cobra.Command {
} }
coderAPI := coderd.New(options) coderAPI := coderd.New(options)
defer coderAPI.Close()
client := codersdk.New(localURL) client := codersdk.New(localURL)
if tlsEnable { if tlsEnable {
// Secure transport isn't needed for locally communicating! // Secure transport isn't needed for locally communicating!
@ -351,64 +368,75 @@ func server() *cobra.Command {
_ = pprof.Handler _ = pprof.Handler
if pprofEnabled { if pprofEnabled {
//nolint:revive //nolint:revive
defer serveHandler(cmd.Context(), logger, nil, pprofAddress, "pprof")() defer serveHandler(ctx, logger, nil, pprofAddress, "pprof")()
} }
if promEnabled { if promEnabled {
//nolint:revive //nolint:revive
defer serveHandler(cmd.Context(), logger, promhttp.Handler(), promAddress, "prometheus")() defer serveHandler(ctx, logger, promhttp.Handler(), promAddress, "prometheus")()
} }
// Since errCh only has one buffered slot, all routines
// sending on it must be wrapped in a select/default to
// avoid leaving dangling goroutines waiting for the
// channel to be consumed.
errCh := make(chan error, 1) errCh := make(chan error, 1)
provisionerDaemons := make([]*provisionerd.Server, 0) provisionerDaemons := make([]*provisionerd.Server, 0)
defer func() {
// We have no graceful shutdown of provisionerDaemons
// here because that's handled at the end of main, this
// is here in case the program exits early.
for _, daemon := range provisionerDaemons {
_ = daemon.Close()
}
}()
for i := 0; uint8(i) < provisionerDaemonCount; i++ { for i := 0; uint8(i) < provisionerDaemonCount; i++ {
daemonClose, err := newProvisionerDaemon(cmd.Context(), coderAPI, logger, cacheDir, errCh, false) daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, cacheDir, errCh, false)
if err != nil { if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err) return xerrors.Errorf("create provisioner daemon: %w", err)
} }
provisionerDaemons = append(provisionerDaemons, daemonClose) provisionerDaemons = append(provisionerDaemons, daemon)
}
shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
defer shutdownConns()
server := &http.Server{
// These errors are typically noise like "TLS: EOF". Vault does similar:
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
ErrorLog: log.New(io.Discard, "", 0),
Handler: coderAPI.Handler,
BaseContext: func(_ net.Listener) context.Context {
return shutdownConnsCtx
},
} }
defer func() { defer func() {
for _, provisionerDaemon := range provisionerDaemons { _ = shutdownWithTimeout(server, 5*time.Second)
_ = provisionerDaemon.Close()
}
}() }()
shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context()) eg := errgroup.Group{}
defer shutdownConns() eg.Go(func() error {
go func() { // Make sure to close the tunnel listener if we exit so the
server := http.Server{ // errgroup doesn't wait forever!
// These errors are typically noise like "TLS: EOF". Vault does similar:
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
ErrorLog: log.New(io.Discard, "", 0),
Handler: coderAPI.Handler,
BaseContext: func(_ net.Listener) context.Context {
return shutdownConnsCtx
},
}
wg := errgroup.Group{}
wg.Go(func() error {
// Make sure to close the tunnel listener if we exit so the
// errgroup doesn't wait forever!
if tunnel {
defer devTunnel.Listener.Close()
}
return server.Serve(listener)
})
if tunnel { if tunnel {
wg.Go(func() error { defer devTunnel.Listener.Close()
defer listener.Close()
return server.Serve(devTunnel.Listener)
})
} }
errCh <- wg.Wait() return server.Serve(listener)
})
if tunnel {
eg.Go(func() error {
defer listener.Close()
return server.Serve(devTunnel.Listener)
})
}
go func() {
select {
case errCh <- eg.Wait():
default:
}
}() }()
hasFirstUser, err := client.HasFirstUser(cmd.Context()) hasFirstUser, err := client.HasFirstUser(ctx)
if !hasFirstUser && err == nil { if !hasFirstUser && err == nil {
cmd.Println() cmd.Println()
cmd.Println("Get started by creating the first user (in a new terminal):") cmd.Println("Get started by creating the first user (in a new terminal):")
@ -425,75 +453,117 @@ func server() *cobra.Command {
autobuildPoller := time.NewTicker(autobuildPollInterval) autobuildPoller := time.NewTicker(autobuildPollInterval)
defer autobuildPoller.Stop() defer autobuildPoller.Stop()
autobuildExecutor := executor.New(cmd.Context(), options.Database, logger, autobuildPoller.C) autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
autobuildExecutor.Run() autobuildExecutor.Run()
// This is helpful for tests, but can be silently ignored.
// Coder may be ran as users that don't have permission to write in the homedir,
// such as via the systemd service.
_ = config.URL().Write(client.URL.String())
// Because the graceful shutdown includes cleaning up workspaces in dev mode, we're // Because the graceful shutdown includes cleaning up workspaces in dev mode, we're
// going to make it harder to accidentally skip the graceful shutdown by hitting ctrl+c // going to make it harder to accidentally skip the graceful shutdown by hitting ctrl+c
// two or more times. So the stopChan is unlimited in size and we don't call // two or more times. So the stopChan is unlimited in size and we don't call
// signal.Stop() until graceful shutdown finished--this means we swallow additional // signal.Stop() until graceful shutdown finished--this means we swallow additional
// SIGINT after the first. To get out of a graceful shutdown, the user can send SIGQUIT // SIGINT after the first. To get out of a graceful shutdown, the user can send SIGQUIT
// with ctrl+\ or SIGTERM with `kill`. // with ctrl+\ or SIGTERM with `kill`.
stopChan := make(chan os.Signal, 1) ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
defer signal.Stop(stopChan) defer stop()
signal.Notify(stopChan, os.Interrupt)
// This is helpful for tests, but can be silently ignored.
// Coder may be ran as users that don't have permission to write in the homedir,
// such as via the systemd service.
_ = config.URL().Write(client.URL.String())
// Currently there is no way to ask the server to shut
// itself down, so any exit signal will result in a non-zero
// exit of the server.
var exitErr error
select { select {
case <-cmd.Context().Done(): case <-ctx.Done():
coderAPI.Close() exitErr = ctx.Err()
return cmd.Context().Err() _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
case err := <-devTunnelErrChan: "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
if err != nil { ))
return err case exitErr = <-devTunnelErr:
if exitErr == nil {
exitErr = xerrors.New("dev tunnel closed unexpectedly")
} }
case err := <-errCh: case exitErr = <-errCh:
shutdownConns()
coderAPI.Close()
return err
case <-stopChan:
} }
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr)
}
// Begin clean shut down stage, we try to shut down services
// gracefully in an order that gives the best experience.
// This procedure should not differ greatly from the order
// of `defer`s in this function, but allows us to inform
// the user about what's going on and handle errors more
// explicitly.
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping) _, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
if err != nil { if err != nil {
return xerrors.Errorf("notify systemd: %w", err) cmd.Printf("Notify systemd failed: %s", err)
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
for _, provisionerDaemon := range provisionerDaemons {
if verbose {
cmd.Println("Shutting down provisioner daemon...")
}
err = provisionerDaemon.Shutdown(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to shutdown provisioner daemon: %s\n", err)
continue
}
err = provisionerDaemon.Close()
if err != nil {
return xerrors.Errorf("close provisioner daemon: %w", err)
}
if verbose {
cmd.Println("Gracefully shut down provisioner daemon!")
}
} }
// Stop accepting new connections without interrupting
// in-flight requests, give in-flight requests 5 seconds to
// complete.
cmd.Println("Shutting down API server...")
err = shutdownWithTimeout(server, 5*time.Second)
if err != nil {
cmd.Printf("API server shutdown took longer than 5s: %s", err)
} else {
cmd.Printf("Gracefully shut down API server\n")
}
// Cancel any remaining in-flight requests.
shutdownConns()
// Shut down provisioners before waiting for WebSockets
// connections to close.
var wg sync.WaitGroup
for i, provisionerDaemon := range provisionerDaemons {
id := i + 1
provisionerDaemon := provisionerDaemon
wg.Add(1)
go func() {
defer wg.Done()
if verbose {
cmd.Printf("Shutting down provisioner daemon %d...\n", id)
}
err := shutdownWithTimeout(provisionerDaemon, 5*time.Second)
if err != nil {
cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err)
return
}
err = provisionerDaemon.Close()
if err != nil {
cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
return
}
if verbose {
cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
}
}()
}
wg.Wait()
cmd.Println("Waiting for WebSocket connections to close...")
_ = coderAPI.Close()
cmd.Println("Done waiting for WebSocket connections")
// Close tunnel after we no longer have in-flight connections.
if tunnel { if tunnel {
cmd.Println("Waiting for tunnel to close...") cmd.Println("Waiting for tunnel to close...")
closeTunnel() closeTunnel()
<-devTunnelErrChan <-devTunnelErr
cmd.Println("Done waiting for tunnel")
} }
// Ensures a last report can be sent before exit! // Ensures a last report can be sent before exit!
options.Telemetry.Close() options.Telemetry.Close()
cmd.Println("Waiting for WebSocket connections to close...")
shutdownConns() // Trigger context cancellation for any remaining services.
coderAPI.Close() cancel()
return nil
return exitErr
}, },
} }
@ -602,16 +672,37 @@ func server() *cobra.Command {
return root return root
} }
func shutdownWithTimeout(s interface{ Shutdown(context.Context) error }, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return s.Shutdown(ctx)
}
// nolint:revive // nolint:revive
func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API, func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
logger slog.Logger, cacheDir string, errChan chan error, dev bool) (*provisionerd.Server, error) { logger slog.Logger, cacheDir string, errCh chan error, dev bool,
err := os.MkdirAll(cacheDir, 0700) ) (srv *provisionerd.Server, err error) {
ctx, cancel := context.WithCancel(ctx)
defer func() {
if err != nil {
cancel()
}
}()
err = os.MkdirAll(cacheDir, 0o700)
if err != nil { if err != nil {
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err) return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
} }
terraformClient, terraformServer := provisionersdk.TransportPipe() terraformClient, terraformServer := provisionersdk.TransportPipe()
go func() { go func() {
<-ctx.Done()
_ = terraformClient.Close()
_ = terraformServer.Close()
}()
go func() {
defer cancel()
err := terraform.Serve(ctx, &terraform.ServeOptions{ err := terraform.Serve(ctx, &terraform.ServeOptions{
ServeOptions: &provisionersdk.ServeOptions{ ServeOptions: &provisionersdk.ServeOptions{
Listener: terraformServer, Listener: terraformServer,
@ -620,7 +711,10 @@ func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
Logger: logger, Logger: logger,
}) })
if err != nil && !xerrors.Is(err, context.Canceled) { if err != nil && !xerrors.Is(err, context.Canceled) {
errChan <- err select {
case errCh <- err:
default:
}
} }
}() }()
@ -636,9 +730,19 @@ func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
if dev { if dev {
echoClient, echoServer := provisionersdk.TransportPipe() echoClient, echoServer := provisionersdk.TransportPipe()
go func() { go func() {
<-ctx.Done()
_ = echoClient.Close()
_ = echoServer.Close()
}()
go func() {
defer cancel()
err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer}) err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
if err != nil { if err != nil {
errChan <- err select {
case errCh <- err:
default:
}
} }
}() }()
provisioners[string(database.ProvisionerTypeEcho)] = proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)) provisioners[string(database.ProvisionerTypeEcho)] = proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient))

View File

@ -17,7 +17,6 @@ import (
"net/url" "net/url"
"os" "os"
"runtime" "runtime"
"strings"
"testing" "testing"
"time" "time"
@ -30,6 +29,7 @@ import (
"github.com/coder/coder/coderd/database/postgres" "github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/pty/ptytest"
) )
// This cannot be ran in parallel because it uses a signal. // This cannot be ran in parallel because it uses a signal.
@ -45,13 +45,14 @@ func TestServer(t *testing.T) {
defer closeFunc() defer closeFunc()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, cfg := clitest.New(t, root, cfg := clitest.New(t,
"server", "server",
"--address", ":0", "--address", ":0",
"--postgres-url", connectionURL, "--postgres-url", connectionURL,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error) errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()
@ -80,12 +81,17 @@ func TestServer(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t, root, cfg := clitest.New(t,
"server", "server",
"--address", ":0", "--address", ":0",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error) pty := ptytest.New(t)
root.SetOutput(pty.Output())
root.SetErr(pty.Output())
errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()
@ -99,11 +105,12 @@ func TestServer(t *testing.T) {
t.Run("BuiltinPostgresURL", func(t *testing.T) { t.Run("BuiltinPostgresURL", func(t *testing.T) {
t.Parallel() t.Parallel()
root, _ := clitest.New(t, "server", "postgres-builtin-url") root, _ := clitest.New(t, "server", "postgres-builtin-url")
var buf strings.Builder pty := ptytest.New(t)
root.SetOutput(&buf) root.SetOutput(pty.Output())
err := root.Execute() err := root.Execute()
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, buf.String(), "psql")
pty.ExpectMatch("psql")
}) })
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) { t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
@ -118,9 +125,9 @@ func TestServer(t *testing.T) {
"--access-url", "http://1.2.3.4:3000/", "--access-url", "http://1.2.3.4:3000/",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
var buf strings.Builder buf := newThreadSafeBuffer()
errC := make(chan error) root.SetOutput(buf)
root.SetOutput(&buf) errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()
@ -142,6 +149,7 @@ func TestServer(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, root, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
@ -157,6 +165,7 @@ func TestServer(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, root, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
@ -172,6 +181,7 @@ func TestServer(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, root, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
@ -197,7 +207,7 @@ func TestServer(t *testing.T) {
"--tls-key-file", keyPath, "--tls-key-file", keyPath,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error) errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()
@ -236,6 +246,7 @@ func TestServer(t *testing.T) {
} }
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, cfg := clitest.New(t, root, cfg := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
@ -243,7 +254,7 @@ func TestServer(t *testing.T) {
"--provisioner-daemons", "1", "--provisioner-daemons", "1",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
serverErr := make(chan error) serverErr := make(chan error, 1)
go func() { go func() {
serverErr <- root.ExecuteContext(ctx) serverErr <- root.ExecuteContext(ctx)
}() }()
@ -259,12 +270,13 @@ func TestServer(t *testing.T) {
// We cannot send more signals here, because it's possible Coder // We cannot send more signals here, because it's possible Coder
// has already exited, which could cause the test to fail due to interrupt. // has already exited, which could cause the test to fail due to interrupt.
err = <-serverErr err = <-serverErr
require.NoError(t, err) require.ErrorIs(t, err, context.Canceled)
}) })
t.Run("TracerNoLeak", func(t *testing.T) { t.Run("TracerNoLeak", func(t *testing.T) {
t.Parallel() t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc() defer cancelFunc()
root, _ := clitest.New(t, root, _ := clitest.New(t,
"server", "server",
"--in-memory", "--in-memory",
@ -272,7 +284,7 @@ func TestServer(t *testing.T) {
"--trace=true", "--trace=true",
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error) errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()
@ -310,7 +322,7 @@ func TestServer(t *testing.T) {
"--telemetry-url", server.URL, "--telemetry-url", server.URL,
"--cache-dir", t.TempDir(), "--cache-dir", t.TempDir(),
) )
errC := make(chan error) errC := make(chan error, 1)
go func() { go func() {
errC <- root.ExecuteContext(ctx) errC <- root.ExecuteContext(ctx)
}() }()

View File

@ -54,15 +54,27 @@ func (e *Executor) WithStatsChannel(ch chan<- Stats) *Executor {
// its channel is closed. // its channel is closed.
func (e *Executor) Run() { func (e *Executor) Run() {
go func() { go func() {
for t := range e.tick { for {
stats := e.runOnce(t) select {
if stats.Error != nil { case <-e.ctx.Done():
e.log.Error(e.ctx, "error running once", slog.Error(stats.Error)) return
case t, ok := <-e.tick:
if !ok {
return
}
stats := e.runOnce(t)
if stats.Error != nil {
e.log.Error(e.ctx, "error running once", slog.Error(stats.Error))
}
if e.statsCh != nil {
select {
case <-e.ctx.Done():
return
case e.statsCh <- stats:
}
}
e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
} }
if e.statsCh != nil {
e.statsCh <- stats
}
e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
} }
}() }()
} }

View File

@ -108,7 +108,7 @@ allowed_ip=%s/128`,
return nil, nil, xerrors.Errorf("wireguard device listen: %w", err) return nil, nil, xerrors.Errorf("wireguard device listen: %w", err)
} }
ch := make(chan error) ch := make(chan error, 1)
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@ -38,6 +38,10 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser
if err != nil { if err != nil {
return xerrors.Errorf("create yamux: %w", err) return xerrors.Errorf("create yamux: %w", err)
} }
go func() {
<-ctx.Done()
_ = stdio.Close()
}()
options.Listener = stdio options.Listener = stdio
} }