fix(agent): fix deadlock if closed while starting listeners (#17329)

fixes #17328

Fixes a deadlock if we close the Agent in the middle of starting listeners on the tailnet.
This commit is contained in:
Spike Curtis
2025-04-10 12:46:19 +04:00
committed by GitHub
parent 8faaa14820
commit c1816e3674
2 changed files with 78 additions and 19 deletions

View File

@ -229,13 +229,21 @@ type agent struct {
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
// to start gracefully shutting down and "hard" which is Done when it is time to close
// everything down (regardless of whether graceful shutdown completed).
gracefulCtx context.Context
gracefulCancel context.CancelFunc
hardCtx context.Context
hardCancel context.CancelFunc
closeWaitGroup sync.WaitGroup
gracefulCtx context.Context
gracefulCancel context.CancelFunc
hardCtx context.Context
hardCancel context.CancelFunc
// closeMutex protects the following:
closeMutex sync.Mutex
closeWaitGroup sync.WaitGroup
coordDisconnected chan struct{}
closing bool
// note that once the network is set to non-nil, it is never modified, as with the statsReporter. So, routines
// that run after createOrUpdateNetwork and check the networkOK checkpoint do not need to hold the lock to use them.
network *tailnet.Conn
statsReporter *statsReporter
// end fields protected by closeMutex
environmentVariables map[string]string
@ -259,9 +267,7 @@ type agent struct {
reportConnectionsMu sync.Mutex
reportConnections []*proto.ReportConnectionRequest
network *tailnet.Conn
statsReporter *statsReporter
logSender *agentsdk.LogSender
logSender *agentsdk.LogSender
prometheusRegistry *prometheus.Registry
// metrics are prometheus registered metrics that will be collected and
@ -274,6 +280,8 @@ type agent struct {
}
func (a *agent) TailnetConn() *tailnet.Conn {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
return a.network
}
@ -1205,15 +1213,15 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
}
a.closeMutex.Lock()
// Re-check if agent was closed while initializing the network.
closed := a.isClosed()
if !closed {
closing := a.closing
if !closing {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
}
a.closeMutex.Unlock()
if closed {
if closing {
_ = network.Close()
return xerrors.New("agent is closed")
return xerrors.New("agent is closing")
}
} else {
// Update the wireguard IPs if the agent ID changed.
@ -1328,8 +1336,8 @@ func (*agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
func (a *agent) trackGoroutine(fn func()) error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return xerrors.New("track conn goroutine: agent is closed")
if a.closing {
return xerrors.New("track conn goroutine: agent is closing")
}
a.closeWaitGroup.Add(1)
go func() {
@ -1547,7 +1555,7 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
func (a *agent) setCoordDisconnected() chan struct{} {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
if a.closing {
return nil
}
disconnected := make(chan struct{})
@ -1772,7 +1780,10 @@ func (a *agent) HTTPDebug() http.Handler {
func (a *agent) Close() error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
network := a.network
coordDisconnected := a.coordDisconnected
a.closing = true
a.closeMutex.Unlock()
if a.isClosed() {
return nil
}
@ -1849,7 +1860,7 @@ lifecycleWaitLoop:
select {
case <-a.hardCtx.Done():
a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect")
case <-a.coordDisconnected:
case <-coordDisconnected:
a.logger.Debug(context.Background(), "coordinator RPC disconnected")
}
@ -1860,8 +1871,8 @@ lifecycleWaitLoop:
}
a.hardCancel()
if a.network != nil {
_ = a.network.Close()
if network != nil {
_ = network.Close()
}
a.closeWaitGroup.Wait()

View File

@ -68,6 +68,54 @@ func TestMain(m *testing.M) {
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
// TestAgent_CloseWhileStarting is a regression test for https://github.com/coder/coder/issues/17328
func TestAgent_ImmediateClose(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{
// Agent can drop errors when shutting down, and some, like the
// fasthttplistener connection closed error, are unexported.
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
AgentName: "test-agent",
WorkspaceName: "test-workspace",
WorkspaceID: uuid.New(),
}
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs()
client := agenttest.NewClient(t, logger.Named("agenttest"), manifest.AgentID, manifest, statsCh, coordinator)
t.Cleanup(client.Close)
options := agent.Options{
Client: client,
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: 0,
EnvironmentVariables: map[string]string{},
}
agentUnderTest := agent.New(options)
t.Cleanup(func() {
_ = agentUnderTest.Close()
})
// wait until the agent has connected and is starting to find races in the startup code
_ = testutil.RequireRecvCtx(ctx, t, client.GetStartup())
t.Log("Closing Agent")
err := agentUnderTest.Close()
require.NoError(t, err)
}
// NOTE: These tests only work when your default shell is bash for some reason.
func TestAgent_Stats_SSH(t *testing.T) {