diff --git a/agent/agent.go b/agent/agent.go index 54d17bb929..b87cf18272 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -398,24 +398,28 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ } }() if err = a.trackConnGoroutine(func() { + var wg sync.WaitGroup for { conn, err := sshListener.Accept() if err != nil { - return + break } + wg.Add(1) closed := make(chan struct{}) - _ = a.trackConnGoroutine(func() { + go func() { select { - case <-network.Closed(): case <-closed: + case <-a.closed: + _ = conn.Close() } - _ = conn.Close() - }) - _ = a.trackConnGoroutine(func() { + wg.Done() + }() + go func() { defer close(closed) a.sshServer.HandleConn(conn) - }) + }() } + wg.Wait() }); err != nil { return nil, err } @@ -431,35 +435,47 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ }() if err = a.trackConnGoroutine(func() { logger := a.logger.Named("reconnecting-pty") - + var wg sync.WaitGroup for { conn, err := reconnectingPTYListener.Accept() if err != nil { logger.Debug(ctx, "accept pty failed", slog.Error(err)) - return - } - // This cannot use a JSON decoder, since that can - // buffer additional data that is required for the PTY. - rawLen := make([]byte, 2) - _, err = conn.Read(rawLen) - if err != nil { - continue - } - length := binary.LittleEndian.Uint16(rawLen) - data := make([]byte, length) - _, err = conn.Read(data) - if err != nil { - continue - } - var msg codersdk.WorkspaceAgentReconnectingPTYInit - err = json.Unmarshal(data, &msg) - if err != nil { - continue + break } + wg.Add(1) + closed := make(chan struct{}) go func() { + select { + case <-closed: + case <-a.closed: + _ = conn.Close() + } + wg.Done() + }() + go func() { + defer close(closed) + // This cannot use a JSON decoder, since that can + // buffer additional data that is required for the PTY. + rawLen := make([]byte, 2) + _, err = conn.Read(rawLen) + if err != nil { + return + } + length := binary.LittleEndian.Uint16(rawLen) + data := make([]byte, length) + _, err = conn.Read(data) + if err != nil { + return + } + var msg codersdk.WorkspaceAgentReconnectingPTYInit + err = json.Unmarshal(data, &msg) + if err != nil { + return + } _ = a.handleReconnectingPTY(ctx, logger, msg, conn) }() } + wg.Wait() }); err != nil { return nil, err } @@ -474,20 +490,29 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ } }() if err = a.trackConnGoroutine(func() { + var wg sync.WaitGroup for { conn, err := speedtestListener.Accept() if err != nil { a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) - return + break } - if err = a.trackConnGoroutine(func() { + wg.Add(1) + closed := make(chan struct{}) + go func() { + select { + case <-closed: + case <-a.closed: + _ = conn.Close() + } + wg.Done() + }() + go func() { + defer close(closed) _ = speedtest.ServeConn(conn) - }); err != nil { - a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) - _ = conn.Close() - return - } + }() } + wg.Wait() }); err != nil { return nil, err } @@ -511,7 +536,10 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo), } go func() { - <-ctx.Done() + select { + case <-ctx.Done(): + case <-a.closed: + } _ = server.Close() }() diff --git a/scaletest/reconnectingpty/run_test.go b/scaletest/reconnectingpty/run_test.go index b47cf98b03..f6f70bbf57 100644 --- a/scaletest/reconnectingpty/run_test.go +++ b/scaletest/reconnectingpty/run_test.go @@ -23,10 +23,6 @@ import ( func Test_Runner(t *testing.T) { t.Parallel() - // There's a race condition in agent/agent.go where connections - // aren't closed when the Tailnet connection is. This causes the - // goroutines to hang around and cause the test to fail. - t.Skip("TODO: fix this test") t.Run("OK", func(t *testing.T) { t.Parallel()