fix: fix flake in TestWorkspaceAgentClientCoordinate_ResumeToken (#14642)

fixes #14365

I bet what's going on is that in `connectToCoordinatorAndFetchResumeToken()` we call `Coordinate()`, send a message on the `Coordinate` client, and then close it in rapid succession. We don't wait around for a response from the coordinator, so dRPC is likely aborting the `Coordinate()` call in the backend because the stream is closed before the call even gets a chance to run.

Instead of using the Coordinator to record the peer ID assigned on the API call, we can wrap the resume token provider, since we call that API _and_ wait for a response. This also affords the opportunity to directly assert we get called with the right token.
This commit is contained in:
Spike Curtis
2024-09-11 16:32:47 +04:00
committed by GitHub
parent 1b5f3418d3
commit 5bd19f8ba3
3 changed files with 44 additions and 42 deletions

View File

@ -864,6 +864,8 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
})
return
}
api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
slog.F("peer_id", peerID.String()))
}
api.WebsocketWaitMutex.Lock()

View File

@ -513,30 +513,42 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
require.Equal(t, "version", sdkErr.Validations[0].Field)
}
type resumeTokenTestFakeCoordinator struct {
tailnet.Coordinator
t testing.TB
peerIDCh chan uuid.UUID
// resumeTokenRecordingProvider is a test wrapper around a
// tailnet.ResumeTokenProvider. Its GenerateResumeToken and VerifyResumeToken
// methods publish the argument of each call on a channel before delegating to
// the embedded provider, so a test can assert on what was requested.
type resumeTokenRecordingProvider struct {
// Embedded underlying provider that the recording methods delegate to.
tailnet.ResumeTokenProvider
// t is used to report an error when a recording channel is full.
t testing.TB
// generateCalls receives the peer ID passed to each GenerateResumeToken call.
generateCalls chan uuid.UUID
// verifyCalls receives the token string passed to each VerifyResumeToken call.
verifyCalls chan string
}
var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{}
// Compile-time check that *resumeTokenRecordingProvider implements
// tailnet.ResumeTokenProvider.
var _ tailnet.ResumeTokenProvider = (*resumeTokenRecordingProvider)(nil)
func (c *resumeTokenTestFakeCoordinator) storeID(id uuid.UUID) {
select {
case c.peerIDCh <- id:
default:
c.t.Fatal("peer ID channel full")
// newResumeTokenRecordingProvider returns a recording wrapper around
// underlying. Both recording channels are created with a buffer of one, so a
// single call of each kind can be recorded before the test has to drain the
// corresponding channel.
func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeTokenProvider) *resumeTokenRecordingProvider {
	recorder := &resumeTokenRecordingProvider{ResumeTokenProvider: underlying, t: t}
	recorder.generateCalls = make(chan uuid.UUID, 1)
	recorder.verifyCalls = make(chan string, 1)
	return recorder
}
func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
c.storeID(id)
return c.Coordinator.ServeClient(conn, id, agentID)
// GenerateResumeToken records peerID on generateCalls and then delegates to
// the wrapped provider. If the channel cannot accept the value without
// blocking, the test is marked as failed and an error is returned without
// calling the wrapped provider.
func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
	const fullMsg = "generateCalls full"

	// Non-blocking send: never stall the caller on an undrained channel.
	select {
	case r.generateCalls <- peerID:
	default:
		r.t.Error(fullMsg)
		return nil, xerrors.New(fullMsg)
	}
	return r.ResumeTokenProvider.GenerateResumeToken(peerID)
}
func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
c.storeID(id)
return c.Coordinator.Coordinate(ctx, id, name, a)
// VerifyResumeToken records token on verifyCalls and then delegates to the
// wrapped provider. If the channel cannot accept the value without blocking,
// the test is marked as failed and uuid.Nil is returned with an error, without
// calling the wrapped provider.
func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
	const fullMsg = "verifyCalls full"

	// Non-blocking send: never stall the caller on an undrained channel.
	select {
	case r.verifyCalls <- token:
	default:
		r.t.Error(fullMsg)
		return uuid.Nil, xerrors.New(fullMsg)
	}
	return r.ResumeTokenProvider.VerifyResumeToken(token)
}
func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
@ -546,15 +558,12 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
clock := quartz.NewMock(t)
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
require.NoError(t, err)
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
coordinator := &resumeTokenTestFakeCoordinator{
Coordinator: tailnet.NewCoordinator(logger),
t: t,
peerIDCh: make(chan uuid.UUID, 1),
}
defer close(coordinator.peerIDCh)
resumeTokenProvider := newResumeTokenRecordingProvider(
t,
tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour),
)
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Coordinator: coordinator,
Coordinator: tailnet.NewCoordinator(logger),
CoordinatorResumeTokenProvider: resumeTokenProvider,
})
defer closer.Close()
@ -576,7 +585,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
// random value.
originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
require.NoError(t, err)
originalPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh)
originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.NotEqual(t, originalPeerID, uuid.Nil)
// Connect with a valid resume token, and ensure that the peer ID is set to
@ -584,7 +593,9 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
clock.Advance(time.Second)
newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
require.NoError(t, err)
newPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh)
verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, originalResumeToken, verifiedToken)
newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
require.Equal(t, originalPeerID, newPeerID)
require.NotEqual(t, originalResumeToken, newResumeToken)
@ -598,9 +609,11 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
require.Len(t, sdkErr.Validations, 1)
require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
require.Equal(t, "invalid", verifiedToken)
select {
case <-coordinator.peerIDCh:
case <-resumeTokenProvider.generateCalls:
t.Fatal("unexpected peer ID in channel")
default:
}
@ -646,21 +659,6 @@ func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Lo
return "", xerrors.Errorf("new dRPC client: %w", err)
}
// Send an empty coordination request. This will do nothing on the server,
// but ensures our wrapped coordinator can record the peer ID.
coordinateClient, err := rpcClient.Coordinate(ctx)
if err != nil {
return "", xerrors.Errorf("coordinate: %w", err)
}
err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
if err != nil {
return "", xerrors.Errorf("send empty coordination request: %w", err)
}
err = coordinateClient.Close()
if err != nil {
return "", xerrors.Errorf("close coordination request: %w", err)
}
// Fetch a resume token.
newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
if err != nil {