mirror of
https://github.com/coder/coder.git
synced 2025-03-14 10:09:57 +00:00
fix: make handleManifest always signal dependents (#13141)
Fixes #13139 Using a bare channel to signal dependent goroutines means that we can only signal success, not failure, which leads to deadlock if we fail in a way that doesn't cause the whole `apiConnRoutineManager` to tear down routines. Instead, we use a new object called a `checkpoint` that signals success or failure, so that dependent routines get unblocked if the routine they depend on fails.
This commit is contained in:
@ -807,23 +807,21 @@ func (a *agent) run() (retErr error) {
|
||||
// coordination <--------------------------+
|
||||
// derp map subscriber <----------------+
|
||||
// stats report loop <---------------+
|
||||
networkOK := make(chan struct{})
|
||||
manifestOK := make(chan struct{})
|
||||
networkOK := newCheckpoint(a.logger)
|
||||
manifestOK := newCheckpoint(a.logger)
|
||||
|
||||
connMan.start("handle manifest", gracefulShutdownBehaviorStop, a.handleManifest(manifestOK))
|
||||
|
||||
connMan.start("app health reporter", gracefulShutdownBehaviorStop,
|
||||
func(ctx context.Context, conn drpc.Conn) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-manifestOK:
|
||||
manifest := a.manifest.Load()
|
||||
NewWorkspaceAppHealthReporter(
|
||||
a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
|
||||
)(ctx)
|
||||
return nil
|
||||
if err := manifestOK.wait(ctx); err != nil {
|
||||
return xerrors.Errorf("no manifest: %w", err)
|
||||
}
|
||||
manifest := a.manifest.Load()
|
||||
NewWorkspaceAppHealthReporter(
|
||||
a.logger, manifest.Apps, agentsdk.AppHealthPoster(proto.NewDRPCAgentClient(conn)),
|
||||
)(ctx)
|
||||
return nil
|
||||
})
|
||||
|
||||
connMan.start("create or update network", gracefulShutdownBehaviorStop,
|
||||
@ -831,10 +829,8 @@ func (a *agent) run() (retErr error) {
|
||||
|
||||
connMan.start("coordination", gracefulShutdownBehaviorStop,
|
||||
func(ctx context.Context, conn drpc.Conn) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-networkOK:
|
||||
if err := networkOK.wait(ctx); err != nil {
|
||||
return xerrors.Errorf("no network: %w", err)
|
||||
}
|
||||
return a.runCoordinator(ctx, conn, a.network)
|
||||
},
|
||||
@ -842,10 +838,8 @@ func (a *agent) run() (retErr error) {
|
||||
|
||||
connMan.start("derp map subscriber", gracefulShutdownBehaviorStop,
|
||||
func(ctx context.Context, conn drpc.Conn) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-networkOK:
|
||||
if err := networkOK.wait(ctx); err != nil {
|
||||
return xerrors.Errorf("no network: %w", err)
|
||||
}
|
||||
return a.runDERPMapSubscriber(ctx, conn, a.network)
|
||||
})
|
||||
@ -853,10 +847,8 @@ func (a *agent) run() (retErr error) {
|
||||
connMan.start("fetch service banner loop", gracefulShutdownBehaviorStop, a.fetchServiceBannerLoop)
|
||||
|
||||
connMan.start("stats report loop", gracefulShutdownBehaviorStop, func(ctx context.Context, conn drpc.Conn) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-networkOK:
|
||||
if err := networkOK.wait(ctx); err != nil {
|
||||
return xerrors.Errorf("no network: %w", err)
|
||||
}
|
||||
return a.statsReporter.reportLoop(ctx, proto.NewDRPCAgentClient(conn))
|
||||
})
|
||||
@ -865,8 +857,17 @@ func (a *agent) run() (retErr error) {
|
||||
}
|
||||
|
||||
// handleManifest returns a function that fetches and processes the manifest
|
||||
func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Context, conn drpc.Conn) error {
|
||||
func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context, conn drpc.Conn) error {
|
||||
return func(ctx context.Context, conn drpc.Conn) error {
|
||||
var (
|
||||
sentResult = false
|
||||
err error
|
||||
)
|
||||
defer func() {
|
||||
if !sentResult {
|
||||
manifestOK.complete(err)
|
||||
}
|
||||
}()
|
||||
aAPI := proto.NewDRPCAgentClient(conn)
|
||||
mp, err := aAPI.GetManifest(ctx, &proto.GetManifestRequest{})
|
||||
if err != nil {
|
||||
@ -907,7 +908,8 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont
|
||||
}
|
||||
|
||||
oldManifest := a.manifest.Swap(&manifest)
|
||||
close(manifestOK)
|
||||
manifestOK.complete(nil)
|
||||
sentResult = true
|
||||
|
||||
// The startup script should only execute on the first run!
|
||||
if oldManifest == nil {
|
||||
@ -968,14 +970,15 @@ func (a *agent) handleManifest(manifestOK chan<- struct{}) func(ctx context.Cont
|
||||
|
||||
// createOrUpdateNetwork waits for the manifest to be set using manifestOK, then creates or updates
|
||||
// the tailnet using the information in the manifest
|
||||
func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan<- struct{}) func(context.Context, drpc.Conn) error {
|
||||
return func(ctx context.Context, _ drpc.Conn) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-manifestOK:
|
||||
func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(context.Context, drpc.Conn) error {
|
||||
return func(ctx context.Context, _ drpc.Conn) (retErr error) {
|
||||
if err := manifestOK.wait(ctx); err != nil {
|
||||
return xerrors.Errorf("no manifest: %w", err)
|
||||
}
|
||||
var err error
|
||||
defer func() {
|
||||
networkOK.complete(retErr)
|
||||
}()
|
||||
manifest := a.manifest.Load()
|
||||
a.closeMutex.Lock()
|
||||
network := a.network
|
||||
@ -1011,7 +1014,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK <-chan struct{}, networkOK chan
|
||||
network.SetDERPForceWebSockets(manifest.DERPForceWebSockets)
|
||||
network.SetBlockEndpoints(manifest.DisableDirectConnections)
|
||||
}
|
||||
close(networkOK)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
51
agent/checkpoint.go
Normal file
51
agent/checkpoint.go
Normal file
@ -0,0 +1,51 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// checkpoint allows a goroutine to communicate when it is OK to proceed beyond some async condition
|
||||
// to other dependent goroutines.
|
||||
type checkpoint struct {
|
||||
logger slog.Logger
|
||||
mu sync.Mutex
|
||||
called bool
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
// complete the checkpoint. Pass nil to indicate the checkpoint was ok. It is an error to call this
|
||||
// more than once.
|
||||
func (c *checkpoint) complete(err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.called {
|
||||
b := make([]byte, 2048)
|
||||
n := runtime.Stack(b, false)
|
||||
c.logger.Critical(context.Background(), "checkpoint complete called more than once", slog.F("stacktrace", b[:n]))
|
||||
return
|
||||
}
|
||||
c.called = true
|
||||
c.err = err
|
||||
close(c.done)
|
||||
}
|
||||
|
||||
func (c *checkpoint) wait(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return c.err
|
||||
}
|
||||
}
|
||||
|
||||
func newCheckpoint(logger slog.Logger) *checkpoint {
|
||||
return &checkpoint{
|
||||
logger: logger,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
49
agent/checkpoint_internal_test.go
Normal file
49
agent/checkpoint_internal_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestCheckpoint_CompleteWait(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
uut := newCheckpoint(logger)
|
||||
err := xerrors.New("test")
|
||||
uut.complete(err)
|
||||
got := uut.wait(ctx)
|
||||
require.Equal(t, err, got)
|
||||
}
|
||||
|
||||
func TestCheckpoint_CompleteTwice(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
uut := newCheckpoint(logger)
|
||||
err := xerrors.New("test")
|
||||
uut.complete(err)
|
||||
uut.complete(nil) // drops CRITICAL log
|
||||
got := uut.wait(ctx)
|
||||
require.Equal(t, err, got)
|
||||
}
|
||||
|
||||
func TestCheckpoint_WaitComplete(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
uut := newCheckpoint(logger)
|
||||
err := xerrors.New("test")
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- uut.wait(ctx)
|
||||
}()
|
||||
uut.complete(err)
|
||||
got := testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
require.Equal(t, err, got)
|
||||
}
|
Reference in New Issue
Block a user