fix: fix flake in TestWorkspaceAgentClientCoordinate_ResumeToken (#14642)

fixes #14365

I bet what's going on is that in `connectToCoordinatorAndFetchResumeToken()` we call `Coordinate()`, send a message on the `Coordinate` client and then close it in rapid succession. We don't wait around for a response from the coordinator, so dRPC is likely aborting the `Coordinate()` call in the backend because the stream is closed before the server even gets a chance to process the request (and record the peer ID).

Instead of using the Coordinator to record the peer ID assigned on the API call, we can wrap the resume token provider, since we call that API _and_ wait for a response. This also affords the opportunity to directly assert we get called with the right token.
This commit is contained in:
Spike Curtis
2024-09-11 16:32:47 +04:00
committed by GitHub
parent 1b5f3418d3
commit 5bd19f8ba3
3 changed files with 44 additions and 42 deletions

View File

@ -864,6 +864,8 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
}) })
return return
} }
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
slog.F("peer_id", peerID.String()))
} }
api.WebsocketWaitMutex.Lock() api.WebsocketWaitMutex.Lock()

View File

@ -513,30 +513,42 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
require.Equal(t, "version", sdkErr.Validations[0].Field) require.Equal(t, "version", sdkErr.Validations[0].Field)
} }
type resumeTokenTestFakeCoordinator struct { type resumeTokenRecordingProvider struct {
tailnet.Coordinator tailnet.ResumeTokenProvider
t testing.TB t testing.TB
peerIDCh chan uuid.UUID generateCalls chan uuid.UUID
verifyCalls chan string
} }
var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{} var _ tailnet.ResumeTokenProvider = &resumeTokenRecordingProvider{}
func (c *resumeTokenTestFakeCoordinator) storeID(id uuid.UUID) { func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeTokenProvider) *resumeTokenRecordingProvider {
select { return &resumeTokenRecordingProvider{
case c.peerIDCh <- id: ResumeTokenProvider: underlying,
default: t: t,
c.t.Fatal("peer ID channel full") generateCalls: make(chan uuid.UUID, 1),
verifyCalls: make(chan string, 1),
} }
} }
func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error { func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
c.storeID(id) select {
return c.Coordinator.ServeClient(conn, id, agentID) case r.generateCalls <- peerID:
return r.ResumeTokenProvider.GenerateResumeToken(peerID)
default:
r.t.Error("generateCalls full")
return nil, xerrors.New("generateCalls full")
}
} }
func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) { func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
c.storeID(id) select {
return c.Coordinator.Coordinate(ctx, id, name, a) case r.verifyCalls <- token:
return r.ResumeTokenProvider.VerifyResumeToken(token)
default:
r.t.Error("verifyCalls full")
return uuid.Nil, xerrors.New("verifyCalls full")
}
} }
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) { func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
@ -546,15 +558,12 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
clock := quartz.NewMock(t) clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey() resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err) require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour) resumeTokenProvider := newResumeTokenRecordingProvider(
coordinator := &resumeTokenTestFakeCoordinator{ t,
Coordinator: tailnet.NewCoordinator(logger), tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour),
t: t, )
peerIDCh: make(chan uuid.UUID, 1),
}
defer close(coordinator.peerIDCh)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{ client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: coordinator, Coordinator: tailnet.NewCoordinator(logger),
CoordinatorResumeTokenProvider: resumeTokenProvider, CoordinatorResumeTokenProvider: resumeTokenProvider,
}) })
defer closer.Close() defer closer.Close()
@ -576,7 +585,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
// random value. // random value.
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "") originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err) require.NoError(t, err)
originalPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh) originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, uuid.Nil) require.NotEqual(t, originalPeerID, uuid.Nil)
// Connect with a valid resume token, and ensure that the peer ID is set to // Connect with a valid resume token, and ensure that the peer ID is set to
@ -584,7 +593,9 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
clock.Advance(time.Second) clock.Advance(time.Second)
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken) newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err) require.NoError(t, err)
newPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh) verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, originalResumeToken, verifiedToken)
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.Equal(t, originalPeerID, newPeerID) require.Equal(t, originalPeerID, newPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken) require.NotEqual(t, originalResumeToken, newResumeToken)
@ -598,9 +609,11 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode()) require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1) require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field) require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, "invalid", verifiedToken)
select { select {
case <-coordinator.peerIDCh: case <-resumeTokenProvider.generateCalls:
t.Fatal("unexpected peer ID in channel") t.Fatal("unexpected peer ID in channel")
default: default:
} }
@ -646,21 +659,6 @@ func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Lo
return "", xerrors.Errorf("new dRPC client: %w", err) return "", xerrors.Errorf("new dRPC client: %w", err)
} }
// Send an empty coordination request. This will do nothing on the server,
// but ensures our wrapped coordinator can record the peer ID.
coordinateClient, err := rpcClient.Coordinate(ctx)
if err != nil {
return "", xerrors.Errorf("coordinate: %w", err)
}
err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
if err != nil {
return "", xerrors.Errorf("send empty coordination request: %w", err)
}
err = coordinateClient.Close()
if err != nil {
return "", xerrors.Errorf("close coordination request: %w", err)
}
// Fetch a resume token. // Fetch a resume token.
newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{}) newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
if err != nil { if err != nil {

View File

@ -119,6 +119,8 @@ func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID
return xerrors.Errorf("yamux init failed: %w", err) return xerrors.Errorf("yamux init failed: %w", err)
} }
ctx = WithStreamID(ctx, streamID) ctx = WithStreamID(ctx, streamID)
s.Logger.Debug(ctx, "serving dRPC tailnet v2 API session",
slog.F("peer_id", streamID.ID.String()))
return s.drpc.Serve(ctx, session) return s.drpc.Serve(ctx, session)
} }