mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
fix: Improve coder server shutdown procedure (#3246)
* fix: Improve `coder server` shutdown procedure

This commit improves the `coder server` shutdown procedure so that all triggers for shutdown do so in a graceful way without skipping any steps.

We also improve cancellation and shutdown of services by ensuring resources are cleaned up at the end.

Notable changes:
- We wrap `cmd.Context()` to allow us to control cancellation better
- We attempt graceful shutdown of the http server (`server.Shutdown`) because it's less abrupt (compared to `shutdownConns`)
- All exit paths share the same shutdown procedure (except for early exit)
- `provisionerd`s are now shut down concurrently instead of one at a time; they also now get a new context for shutdown because `cmd.Context()` may be cancelled
- Resources created by `newProvisionerDaemon` are cleaned up
- Lifecycle `Executor` exits its goroutine on context cancellation

Fixes #3245
This commit is contained in:
committed by
GitHub
parent
bb05b1f749
commit
d27076cac7
314
cli/server.go
314
cli/server.go
@ -20,6 +20,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-systemd/daemon"
|
"github.com/coreos/go-systemd/daemon"
|
||||||
@ -111,26 +112,34 @@ func server() *cobra.Command {
|
|||||||
logger = logger.Leveled(slog.LevelDebug)
|
logger = logger.Leveled(slog.LevelDebug)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Main command context for managing cancellation
|
||||||
|
// of running services.
|
||||||
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Clean up idle connections at the end, e.g.
|
||||||
|
// embedded-postgres can leave an idle connection
|
||||||
|
// which is caught by goleaks.
|
||||||
|
defer http.DefaultClient.CloseIdleConnections()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tracerProvider *sdktrace.TracerProvider
|
tracerProvider *sdktrace.TracerProvider
|
||||||
err error
|
err error
|
||||||
sqlDriver = "postgres"
|
sqlDriver = "postgres"
|
||||||
)
|
)
|
||||||
if trace {
|
if trace {
|
||||||
tracerProvider, err = tracing.TracerProvider(cmd.Context(), "coderd")
|
tracerProvider, err = tracing.TracerProvider(ctx, "coderd")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(cmd.Context(), "failed to start telemetry exporter", slog.Error(err))
|
logger.Warn(ctx, "failed to start telemetry exporter", slog.Error(err))
|
||||||
} else {
|
} else {
|
||||||
|
// allow time for traces to flush even if command context is canceled
|
||||||
defer func() {
|
defer func() {
|
||||||
// allow time for traces to flush even if command context is canceled
|
_ = shutdownWithTimeout(tracerProvider, 5*time.Second)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
_ = tracerProvider.Shutdown(ctx)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
d, err := tracing.PostgresDriver(tracerProvider, "coderd.database")
|
d, err := tracing.PostgresDriver(tracerProvider, "coderd.database")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(cmd.Context(), "failed to start postgres tracing driver", slog.Error(err))
|
logger.Warn(ctx, "failed to start postgres tracing driver", slog.Error(err))
|
||||||
} else {
|
} else {
|
||||||
sqlDriver = d
|
sqlDriver = d
|
||||||
}
|
}
|
||||||
@ -143,14 +152,16 @@ func server() *cobra.Command {
|
|||||||
if !inMemoryDatabase && postgresURL == "" {
|
if !inMemoryDatabase && postgresURL == "" {
|
||||||
var closeFunc func() error
|
var closeFunc func() error
|
||||||
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
|
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
|
||||||
postgresURL, closeFunc, err = startBuiltinPostgres(cmd.Context(), config, logger)
|
postgresURL, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
builtinPostgres = true
|
builtinPostgres = true
|
||||||
defer func() {
|
defer func() {
|
||||||
|
cmd.Printf("Stopping built-in PostgreSQL...\n")
|
||||||
// Gracefully shut PostgreSQL down!
|
// Gracefully shut PostgreSQL down!
|
||||||
_ = closeFunc()
|
_ = closeFunc()
|
||||||
|
cmd.Printf("Stopped built-in PostgreSQL\n")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,9 +200,9 @@ func server() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ctxTunnel, closeTunnel = context.WithCancel(cmd.Context())
|
ctxTunnel, closeTunnel = context.WithCancel(ctx)
|
||||||
devTunnel = (*devtunnel.Tunnel)(nil)
|
devTunnel *devtunnel.Tunnel
|
||||||
devTunnelErrChan = make(<-chan error, 1)
|
devTunnelErr <-chan error
|
||||||
)
|
)
|
||||||
defer closeTunnel()
|
defer closeTunnel()
|
||||||
|
|
||||||
@ -199,7 +210,7 @@ func server() *cobra.Command {
|
|||||||
// needs to be changed to use the tunnel.
|
// needs to be changed to use the tunnel.
|
||||||
if tunnel {
|
if tunnel {
|
||||||
cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n")
|
cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n")
|
||||||
devTunnel, devTunnelErrChan, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
|
devTunnel, devTunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create tunnel: %w", err)
|
return xerrors.Errorf("create tunnel: %w", err)
|
||||||
}
|
}
|
||||||
@ -207,7 +218,7 @@ func server() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Warn the user if the access URL appears to be a loopback address.
|
// Warn the user if the access URL appears to be a loopback address.
|
||||||
isLocal, err := isLocalURL(cmd.Context(), accessURL)
|
isLocal, err := isLocalURL(ctx, accessURL)
|
||||||
if isLocal || err != nil {
|
if isLocal || err != nil {
|
||||||
reason := "could not be resolved"
|
reason := "could not be resolved"
|
||||||
if isLocal {
|
if isLocal {
|
||||||
@ -224,7 +235,7 @@ func server() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Used for zero-trust instance identity with Google Cloud.
|
// Used for zero-trust instance identity with Google Cloud.
|
||||||
googleTokenValidator, err := idtoken.NewValidator(cmd.Context(), option.WithoutAuthentication())
|
googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -241,6 +252,7 @@ func server() *cobra.Command {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create turn server: %w", err)
|
return xerrors.Errorf("create turn server: %w", err)
|
||||||
}
|
}
|
||||||
|
defer turnServer.Close()
|
||||||
|
|
||||||
iceServers := make([]webrtc.ICEServer, 0)
|
iceServers := make([]webrtc.ICEServer, 0)
|
||||||
for _, stunServer := range stunServers {
|
for _, stunServer := range stunServers {
|
||||||
@ -278,6 +290,8 @@ func server() *cobra.Command {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("dial postgres: %w", err)
|
return xerrors.Errorf("dial postgres: %w", err)
|
||||||
}
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
|
||||||
err = sqlDB.Ping()
|
err = sqlDB.Ping()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("ping postgres: %w", err)
|
return xerrors.Errorf("ping postgres: %w", err)
|
||||||
@ -287,13 +301,14 @@ func server() *cobra.Command {
|
|||||||
return xerrors.Errorf("migrate up: %w", err)
|
return xerrors.Errorf("migrate up: %w", err)
|
||||||
}
|
}
|
||||||
options.Database = database.New(sqlDB)
|
options.Database = database.New(sqlDB)
|
||||||
options.Pubsub, err = database.NewPubsub(cmd.Context(), sqlDB, postgresURL)
|
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, postgresURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create pubsub: %w", err)
|
return xerrors.Errorf("create pubsub: %w", err)
|
||||||
}
|
}
|
||||||
|
defer options.Pubsub.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
deploymentID, err := options.Database.GetDeploymentID(cmd.Context())
|
deploymentID, err := options.Database.GetDeploymentID(ctx)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
@ -302,7 +317,7 @@ func server() *cobra.Command {
|
|||||||
}
|
}
|
||||||
if deploymentID == "" {
|
if deploymentID == "" {
|
||||||
deploymentID = uuid.NewString()
|
deploymentID = uuid.NewString()
|
||||||
err = options.Database.InsertDeploymentID(cmd.Context(), deploymentID)
|
err = options.Database.InsertDeploymentID(ctx, deploymentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("set deployment id: %w", err)
|
return xerrors.Errorf("set deployment id: %w", err)
|
||||||
}
|
}
|
||||||
@ -336,6 +351,8 @@ func server() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
coderAPI := coderd.New(options)
|
coderAPI := coderd.New(options)
|
||||||
|
defer coderAPI.Close()
|
||||||
|
|
||||||
client := codersdk.New(localURL)
|
client := codersdk.New(localURL)
|
||||||
if tlsEnable {
|
if tlsEnable {
|
||||||
// Secure transport isn't needed for locally communicating!
|
// Secure transport isn't needed for locally communicating!
|
||||||
@ -351,64 +368,75 @@ func server() *cobra.Command {
|
|||||||
_ = pprof.Handler
|
_ = pprof.Handler
|
||||||
if pprofEnabled {
|
if pprofEnabled {
|
||||||
//nolint:revive
|
//nolint:revive
|
||||||
defer serveHandler(cmd.Context(), logger, nil, pprofAddress, "pprof")()
|
defer serveHandler(ctx, logger, nil, pprofAddress, "pprof")()
|
||||||
}
|
}
|
||||||
if promEnabled {
|
if promEnabled {
|
||||||
//nolint:revive
|
//nolint:revive
|
||||||
defer serveHandler(cmd.Context(), logger, promhttp.Handler(), promAddress, "prometheus")()
|
defer serveHandler(ctx, logger, promhttp.Handler(), promAddress, "prometheus")()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Since errCh only has one buffered slot, all routines
|
||||||
|
// sending on it must be wrapped in a select/default to
|
||||||
|
// avoid leaving dangling goroutines waiting for the
|
||||||
|
// channel to be consumed.
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
provisionerDaemons := make([]*provisionerd.Server, 0)
|
provisionerDaemons := make([]*provisionerd.Server, 0)
|
||||||
|
defer func() {
|
||||||
|
// We have no graceful shutdown of provisionerDaemons
|
||||||
|
// here because that's handled at the end of main, this
|
||||||
|
// is here in case the program exits early.
|
||||||
|
for _, daemon := range provisionerDaemons {
|
||||||
|
_ = daemon.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
|
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
|
||||||
daemonClose, err := newProvisionerDaemon(cmd.Context(), coderAPI, logger, cacheDir, errCh, false)
|
daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, cacheDir, errCh, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create provisioner daemon: %w", err)
|
return xerrors.Errorf("create provisioner daemon: %w", err)
|
||||||
}
|
}
|
||||||
provisionerDaemons = append(provisionerDaemons, daemonClose)
|
provisionerDaemons = append(provisionerDaemons, daemon)
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
|
||||||
|
defer shutdownConns()
|
||||||
|
server := &http.Server{
|
||||||
|
// These errors are typically noise like "TLS: EOF". Vault does similar:
|
||||||
|
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
|
||||||
|
ErrorLog: log.New(io.Discard, "", 0),
|
||||||
|
Handler: coderAPI.Handler,
|
||||||
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
return shutdownConnsCtx
|
||||||
|
},
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, provisionerDaemon := range provisionerDaemons {
|
_ = shutdownWithTimeout(server, 5*time.Second)
|
||||||
_ = provisionerDaemon.Close()
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context())
|
eg := errgroup.Group{}
|
||||||
defer shutdownConns()
|
eg.Go(func() error {
|
||||||
go func() {
|
// Make sure to close the tunnel listener if we exit so the
|
||||||
server := http.Server{
|
// errgroup doesn't wait forever!
|
||||||
// These errors are typically noise like "TLS: EOF". Vault does similar:
|
|
||||||
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
|
|
||||||
ErrorLog: log.New(io.Discard, "", 0),
|
|
||||||
Handler: coderAPI.Handler,
|
|
||||||
BaseContext: func(_ net.Listener) context.Context {
|
|
||||||
return shutdownConnsCtx
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := errgroup.Group{}
|
|
||||||
wg.Go(func() error {
|
|
||||||
// Make sure to close the tunnel listener if we exit so the
|
|
||||||
// errgroup doesn't wait forever!
|
|
||||||
if tunnel {
|
|
||||||
defer devTunnel.Listener.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return server.Serve(listener)
|
|
||||||
})
|
|
||||||
|
|
||||||
if tunnel {
|
if tunnel {
|
||||||
wg.Go(func() error {
|
defer devTunnel.Listener.Close()
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
return server.Serve(devTunnel.Listener)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
errCh <- wg.Wait()
|
return server.Serve(listener)
|
||||||
|
})
|
||||||
|
if tunnel {
|
||||||
|
eg.Go(func() error {
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
return server.Serve(devTunnel.Listener)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case errCh <- eg.Wait():
|
||||||
|
default:
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hasFirstUser, err := client.HasFirstUser(cmd.Context())
|
hasFirstUser, err := client.HasFirstUser(ctx)
|
||||||
if !hasFirstUser && err == nil {
|
if !hasFirstUser && err == nil {
|
||||||
cmd.Println()
|
cmd.Println()
|
||||||
cmd.Println("Get started by creating the first user (in a new terminal):")
|
cmd.Println("Get started by creating the first user (in a new terminal):")
|
||||||
@ -425,75 +453,117 @@ func server() *cobra.Command {
|
|||||||
|
|
||||||
autobuildPoller := time.NewTicker(autobuildPollInterval)
|
autobuildPoller := time.NewTicker(autobuildPollInterval)
|
||||||
defer autobuildPoller.Stop()
|
defer autobuildPoller.Stop()
|
||||||
autobuildExecutor := executor.New(cmd.Context(), options.Database, logger, autobuildPoller.C)
|
autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
|
||||||
autobuildExecutor.Run()
|
autobuildExecutor.Run()
|
||||||
|
|
||||||
|
// This is helpful for tests, but can be silently ignored.
|
||||||
|
// Coder may be ran as users that don't have permission to write in the homedir,
|
||||||
|
// such as via the systemd service.
|
||||||
|
_ = config.URL().Write(client.URL.String())
|
||||||
|
|
||||||
// Because the graceful shutdown includes cleaning up workspaces in dev mode, we're
|
// Because the graceful shutdown includes cleaning up workspaces in dev mode, we're
|
||||||
// going to make it harder to accidentally skip the graceful shutdown by hitting ctrl+c
|
// going to make it harder to accidentally skip the graceful shutdown by hitting ctrl+c
|
||||||
// two or more times. So the stopChan is unlimited in size and we don't call
|
// two or more times. So the stopChan is unlimited in size and we don't call
|
||||||
// signal.Stop() until graceful shutdown finished--this means we swallow additional
|
// signal.Stop() until graceful shutdown finished--this means we swallow additional
|
||||||
// SIGINT after the first. To get out of a graceful shutdown, the user can send SIGQUIT
|
// SIGINT after the first. To get out of a graceful shutdown, the user can send SIGQUIT
|
||||||
// with ctrl+\ or SIGTERM with `kill`.
|
// with ctrl+\ or SIGTERM with `kill`.
|
||||||
stopChan := make(chan os.Signal, 1)
|
ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
defer signal.Stop(stopChan)
|
defer stop()
|
||||||
signal.Notify(stopChan, os.Interrupt)
|
|
||||||
|
|
||||||
// This is helpful for tests, but can be silently ignored.
|
|
||||||
// Coder may be ran as users that don't have permission to write in the homedir,
|
|
||||||
// such as via the systemd service.
|
|
||||||
_ = config.URL().Write(client.URL.String())
|
|
||||||
|
|
||||||
|
// Currently there is no way to ask the server to shut
|
||||||
|
// itself down, so any exit signal will result in a non-zero
|
||||||
|
// exit of the server.
|
||||||
|
var exitErr error
|
||||||
select {
|
select {
|
||||||
case <-cmd.Context().Done():
|
case <-ctx.Done():
|
||||||
coderAPI.Close()
|
exitErr = ctx.Err()
|
||||||
return cmd.Context().Err()
|
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
|
||||||
case err := <-devTunnelErrChan:
|
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
|
||||||
if err != nil {
|
))
|
||||||
return err
|
case exitErr = <-devTunnelErr:
|
||||||
|
if exitErr == nil {
|
||||||
|
exitErr = xerrors.New("dev tunnel closed unexpectedly")
|
||||||
}
|
}
|
||||||
case err := <-errCh:
|
case exitErr = <-errCh:
|
||||||
shutdownConns()
|
|
||||||
coderAPI.Close()
|
|
||||||
return err
|
|
||||||
case <-stopChan:
|
|
||||||
}
|
}
|
||||||
|
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
|
||||||
|
cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin clean shut down stage, we try to shut down services
|
||||||
|
// gracefully in an order that gives the best experience.
|
||||||
|
// This procedure should not differ greatly from the order
|
||||||
|
// of `defer`s in this function, but allows us to inform
|
||||||
|
// the user about what's going on and handle errors more
|
||||||
|
// explicitly.
|
||||||
|
|
||||||
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
|
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("notify systemd: %w", err)
|
cmd.Printf("Notify systemd failed: %s", err)
|
||||||
}
|
|
||||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
|
|
||||||
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
|
|
||||||
|
|
||||||
for _, provisionerDaemon := range provisionerDaemons {
|
|
||||||
if verbose {
|
|
||||||
cmd.Println("Shutting down provisioner daemon...")
|
|
||||||
}
|
|
||||||
err = provisionerDaemon.Shutdown(cmd.Context())
|
|
||||||
if err != nil {
|
|
||||||
cmd.PrintErrf("Failed to shutdown provisioner daemon: %s\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = provisionerDaemon.Close()
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("close provisioner daemon: %w", err)
|
|
||||||
}
|
|
||||||
if verbose {
|
|
||||||
cmd.Println("Gracefully shut down provisioner daemon!")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop accepting new connections without interrupting
|
||||||
|
// in-flight requests, give in-flight requests 5 seconds to
|
||||||
|
// complete.
|
||||||
|
cmd.Println("Shutting down API server...")
|
||||||
|
err = shutdownWithTimeout(server, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
cmd.Printf("API server shutdown took longer than 5s: %s", err)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("Gracefully shut down API server\n")
|
||||||
|
}
|
||||||
|
// Cancel any remaining in-flight requests.
|
||||||
|
shutdownConns()
|
||||||
|
|
||||||
|
// Shut down provisioners before waiting for WebSockets
|
||||||
|
// connections to close.
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i, provisionerDaemon := range provisionerDaemons {
|
||||||
|
id := i + 1
|
||||||
|
provisionerDaemon := provisionerDaemon
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
cmd.Printf("Shutting down provisioner daemon %d...\n", id)
|
||||||
|
}
|
||||||
|
err := shutdownWithTimeout(provisionerDaemon, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = provisionerDaemon.Close()
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if verbose {
|
||||||
|
cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
cmd.Println("Waiting for WebSocket connections to close...")
|
||||||
|
_ = coderAPI.Close()
|
||||||
|
cmd.Println("Done wainting for WebSocket connections")
|
||||||
|
|
||||||
|
// Close tunnel after we no longer have in-flight connections.
|
||||||
if tunnel {
|
if tunnel {
|
||||||
cmd.Println("Waiting for tunnel to close...")
|
cmd.Println("Waiting for tunnel to close...")
|
||||||
closeTunnel()
|
closeTunnel()
|
||||||
<-devTunnelErrChan
|
<-devTunnelErr
|
||||||
|
cmd.Println("Done waiting for tunnel")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensures a last report can be sent before exit!
|
// Ensures a last report can be sent before exit!
|
||||||
options.Telemetry.Close()
|
options.Telemetry.Close()
|
||||||
cmd.Println("Waiting for WebSocket connections to close...")
|
|
||||||
shutdownConns()
|
// Trigger context cancellation for any remaining services.
|
||||||
coderAPI.Close()
|
cancel()
|
||||||
return nil
|
|
||||||
|
return exitErr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -602,16 +672,37 @@ func server() *cobra.Command {
|
|||||||
return root
|
return root
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shutdownWithTimeout(s interface{ Shutdown(context.Context) error }, timeout time.Duration) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
return s.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:revive
|
// nolint:revive
|
||||||
func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
||||||
logger slog.Logger, cacheDir string, errChan chan error, dev bool) (*provisionerd.Server, error) {
|
logger slog.Logger, cacheDir string, errCh chan error, dev bool,
|
||||||
err := os.MkdirAll(cacheDir, 0700)
|
) (srv *provisionerd.Server, err error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = os.MkdirAll(cacheDir, 0o700)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
|
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
terraformClient, terraformServer := provisionersdk.TransportPipe()
|
terraformClient, terraformServer := provisionersdk.TransportPipe()
|
||||||
go func() {
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
_ = terraformClient.Close()
|
||||||
|
_ = terraformServer.Close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
err := terraform.Serve(ctx, &terraform.ServeOptions{
|
err := terraform.Serve(ctx, &terraform.ServeOptions{
|
||||||
ServeOptions: &provisionersdk.ServeOptions{
|
ServeOptions: &provisionersdk.ServeOptions{
|
||||||
Listener: terraformServer,
|
Listener: terraformServer,
|
||||||
@ -620,7 +711,10 @@ func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
|||||||
Logger: logger,
|
Logger: logger,
|
||||||
})
|
})
|
||||||
if err != nil && !xerrors.Is(err, context.Canceled) {
|
if err != nil && !xerrors.Is(err, context.Canceled) {
|
||||||
errChan <- err
|
select {
|
||||||
|
case errCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -636,9 +730,19 @@ func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
|||||||
if dev {
|
if dev {
|
||||||
echoClient, echoServer := provisionersdk.TransportPipe()
|
echoClient, echoServer := provisionersdk.TransportPipe()
|
||||||
go func() {
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
_ = echoClient.Close()
|
||||||
|
_ = echoServer.Close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
|
err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
select {
|
||||||
|
case errCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
provisioners[string(database.ProvisionerTypeEcho)] = proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient))
|
provisioners[string(database.ProvisionerTypeEcho)] = proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient))
|
||||||
|
@ -17,7 +17,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -30,6 +29,7 @@ import (
|
|||||||
"github.com/coder/coder/coderd/database/postgres"
|
"github.com/coder/coder/coderd/database/postgres"
|
||||||
"github.com/coder/coder/coderd/telemetry"
|
"github.com/coder/coder/coderd/telemetry"
|
||||||
"github.com/coder/coder/codersdk"
|
"github.com/coder/coder/codersdk"
|
||||||
|
"github.com/coder/coder/pty/ptytest"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This cannot be run in parallel because it uses a signal.
|
// This cannot be run in parallel because it uses a signal.
|
||||||
@ -45,13 +45,14 @@ func TestServer(t *testing.T) {
|
|||||||
defer closeFunc()
|
defer closeFunc()
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, cfg := clitest.New(t,
|
root, cfg := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--address", ":0",
|
"--address", ":0",
|
||||||
"--postgres-url", connectionURL,
|
"--postgres-url", connectionURL,
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
errC := make(chan error)
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -80,12 +81,17 @@ func TestServer(t *testing.T) {
|
|||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
root, cfg := clitest.New(t,
|
root, cfg := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--address", ":0",
|
"--address", ":0",
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
errC := make(chan error)
|
pty := ptytest.New(t)
|
||||||
|
root.SetOutput(pty.Output())
|
||||||
|
root.SetErr(pty.Output())
|
||||||
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -99,11 +105,12 @@ func TestServer(t *testing.T) {
|
|||||||
t.Run("BuiltinPostgresURL", func(t *testing.T) {
|
t.Run("BuiltinPostgresURL", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
root, _ := clitest.New(t, "server", "postgres-builtin-url")
|
root, _ := clitest.New(t, "server", "postgres-builtin-url")
|
||||||
var buf strings.Builder
|
pty := ptytest.New(t)
|
||||||
root.SetOutput(&buf)
|
root.SetOutput(pty.Output())
|
||||||
err := root.Execute()
|
err := root.Execute()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, buf.String(), "psql")
|
|
||||||
|
pty.ExpectMatch("psql")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
|
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
|
||||||
@ -118,9 +125,9 @@ func TestServer(t *testing.T) {
|
|||||||
"--access-url", "http://1.2.3.4:3000/",
|
"--access-url", "http://1.2.3.4:3000/",
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
var buf strings.Builder
|
buf := newThreadSafeBuffer()
|
||||||
errC := make(chan error)
|
root.SetOutput(buf)
|
||||||
root.SetOutput(&buf)
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -142,6 +149,7 @@ func TestServer(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, _ := clitest.New(t,
|
root, _ := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--in-memory",
|
"--in-memory",
|
||||||
@ -157,6 +165,7 @@ func TestServer(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, _ := clitest.New(t,
|
root, _ := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--in-memory",
|
"--in-memory",
|
||||||
@ -172,6 +181,7 @@ func TestServer(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, _ := clitest.New(t,
|
root, _ := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--in-memory",
|
"--in-memory",
|
||||||
@ -197,7 +207,7 @@ func TestServer(t *testing.T) {
|
|||||||
"--tls-key-file", keyPath,
|
"--tls-key-file", keyPath,
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
errC := make(chan error)
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -236,6 +246,7 @@ func TestServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, cfg := clitest.New(t,
|
root, cfg := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--in-memory",
|
"--in-memory",
|
||||||
@ -243,7 +254,7 @@ func TestServer(t *testing.T) {
|
|||||||
"--provisioner-daemons", "1",
|
"--provisioner-daemons", "1",
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
serverErr := make(chan error)
|
serverErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
serverErr <- root.ExecuteContext(ctx)
|
serverErr <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -259,12 +270,13 @@ func TestServer(t *testing.T) {
|
|||||||
// We cannot send more signals here, because it's possible Coder
|
// We cannot send more signals here, because it's possible Coder
|
||||||
// has already exited, which could cause the test to fail due to interrupt.
|
// has already exited, which could cause the test to fail due to interrupt.
|
||||||
err = <-serverErr
|
err = <-serverErr
|
||||||
require.NoError(t, err)
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
})
|
})
|
||||||
t.Run("TracerNoLeak", func(t *testing.T) {
|
t.Run("TracerNoLeak", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
defer cancelFunc()
|
defer cancelFunc()
|
||||||
|
|
||||||
root, _ := clitest.New(t,
|
root, _ := clitest.New(t,
|
||||||
"server",
|
"server",
|
||||||
"--in-memory",
|
"--in-memory",
|
||||||
@ -272,7 +284,7 @@ func TestServer(t *testing.T) {
|
|||||||
"--trace=true",
|
"--trace=true",
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
errC := make(chan error)
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
@ -310,7 +322,7 @@ func TestServer(t *testing.T) {
|
|||||||
"--telemetry-url", server.URL,
|
"--telemetry-url", server.URL,
|
||||||
"--cache-dir", t.TempDir(),
|
"--cache-dir", t.TempDir(),
|
||||||
)
|
)
|
||||||
errC := make(chan error)
|
errC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
errC <- root.ExecuteContext(ctx)
|
errC <- root.ExecuteContext(ctx)
|
||||||
}()
|
}()
|
||||||
|
@ -54,15 +54,27 @@ func (e *Executor) WithStatsChannel(ch chan<- Stats) *Executor {
|
|||||||
// its channel is closed.
|
// its channel is closed.
|
||||||
func (e *Executor) Run() {
|
func (e *Executor) Run() {
|
||||||
go func() {
|
go func() {
|
||||||
for t := range e.tick {
|
for {
|
||||||
stats := e.runOnce(t)
|
select {
|
||||||
if stats.Error != nil {
|
case <-e.ctx.Done():
|
||||||
e.log.Error(e.ctx, "error running once", slog.Error(stats.Error))
|
return
|
||||||
|
case t, ok := <-e.tick:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stats := e.runOnce(t)
|
||||||
|
if stats.Error != nil {
|
||||||
|
e.log.Error(e.ctx, "error running once", slog.Error(stats.Error))
|
||||||
|
}
|
||||||
|
if e.statsCh != nil {
|
||||||
|
select {
|
||||||
|
case <-e.ctx.Done():
|
||||||
|
return
|
||||||
|
case e.statsCh <- stats:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
|
||||||
}
|
}
|
||||||
if e.statsCh != nil {
|
|
||||||
e.statsCh <- stats
|
|
||||||
}
|
|
||||||
e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ allowed_ip=%s/128`,
|
|||||||
return nil, nil, xerrors.Errorf("wireguard device listen: %w", err)
|
return nil, nil, xerrors.Errorf("wireguard device listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan error)
|
ch := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -38,6 +38,10 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("create yamux: %w", err)
|
return xerrors.Errorf("create yamux: %w", err)
|
||||||
}
|
}
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
_ = stdio.Close()
|
||||||
|
}()
|
||||||
options.Listener = stdio
|
options.Listener = stdio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user