fix: fix flake in TestWorkspaceAgentClientCoordinate_ResumeToken (#14642)
Fixes #14365.

I bet what's going on is that in `connectToCoordinatorAndFetchResumeToken()` we call `Coordinate()`, send a message on the `Coordinate` client, and then close it in rapid succession. We don't wait around for a response from the coordinator, so dRPC is likely aborting the `Coordinate()` call in the backend because the stream is closed before it even gets a chance to handle it. Instead of using the Coordinator to record the peer ID assigned on the API call, we can wrap the resume token provider, since we call that API _and_ wait for a response. This also affords the opportunity to directly assert that the provider gets called with the right token.
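The new test double in this patch follows a common Go testing pattern: embed the interface you want to observe, forward each call to the wrapped implementation, and record the call's arguments on a buffered channel so the test can block until the server has demonstrably processed the request. A minimal, self-contained sketch of that pattern (the `Greeter` names here are illustrative, not from the patch, which wraps `tailnet.ResumeTokenProvider`):

```go
package main

import "fmt"

// Greeter stands in for any interface whose calls a test wants to observe.
type Greeter interface {
	Greet(name string) string
}

type realGreeter struct{}

func (realGreeter) Greet(name string) string { return "hello " + name }

// recordingGreeter forwards to the embedded Greeter and records each call's
// argument on a buffered channel. Receiving from calls proves the wrapped
// method actually ran — the synchronization point the flaky test lacked.
type recordingGreeter struct {
	Greeter
	calls chan string
}

func (r recordingGreeter) Greet(name string) string {
	select {
	case r.calls <- name: // record the call without blocking the caller
	default: // channel full; a real test would fail loudly here
	}
	return r.Greeter.Greet(name)
}

func main() {
	g := recordingGreeter{Greeter: realGreeter{}, calls: make(chan string, 1)}
	fmt.Println(g.Greet("world")) // hello world
	fmt.Println(<-g.calls)        // world
}
```

Because the channel send happens before the underlying call returns to the client, a test that receives from the channel is guaranteed the server-side handler ran, without any sleeps or polling.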
@@ -864,6 +864,8 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
 			})
 			return
 		}
+		api.Logger.Debug(ctx, "accepted coordinate resume token for peer",
+			slog.F("peer_id", peerID.String()))
 	}
 
 	api.WebsocketWaitMutex.Lock()
@@ -513,30 +513,42 @@ func TestWorkspaceAgentClientCoordinate_BadVersion(t *testing.T) {
 	require.Equal(t, "version", sdkErr.Validations[0].Field)
 }
 
-type resumeTokenTestFakeCoordinator struct {
-	tailnet.Coordinator
-	t        testing.TB
-	peerIDCh chan uuid.UUID
+type resumeTokenRecordingProvider struct {
+	tailnet.ResumeTokenProvider
+	t             testing.TB
+	generateCalls chan uuid.UUID
+	verifyCalls   chan string
 }
 
-var _ tailnet.Coordinator = &resumeTokenTestFakeCoordinator{}
+var _ tailnet.ResumeTokenProvider = &resumeTokenRecordingProvider{}
 
-func (c *resumeTokenTestFakeCoordinator) storeID(id uuid.UUID) {
-	select {
-	case c.peerIDCh <- id:
-	default:
-		c.t.Fatal("peer ID channel full")
+func newResumeTokenRecordingProvider(t testing.TB, underlying tailnet.ResumeTokenProvider) *resumeTokenRecordingProvider {
+	return &resumeTokenRecordingProvider{
+		ResumeTokenProvider: underlying,
+		t:                   t,
+		generateCalls:       make(chan uuid.UUID, 1),
+		verifyCalls:         make(chan string, 1),
 	}
 }
 
-func (c *resumeTokenTestFakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agentID uuid.UUID) error {
-	c.storeID(id)
-	return c.Coordinator.ServeClient(conn, id, agentID)
+func (r *resumeTokenRecordingProvider) GenerateResumeToken(peerID uuid.UUID) (*tailnetproto.RefreshResumeTokenResponse, error) {
+	select {
+	case r.generateCalls <- peerID:
+		return r.ResumeTokenProvider.GenerateResumeToken(peerID)
+	default:
+		r.t.Error("generateCalls full")
+		return nil, xerrors.New("generateCalls full")
+	}
 }
 
-func (c *resumeTokenTestFakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.CoordinateeAuth) (chan<- *tailnetproto.CoordinateRequest, <-chan *tailnetproto.CoordinateResponse) {
-	c.storeID(id)
-	return c.Coordinator.Coordinate(ctx, id, name, a)
+func (r *resumeTokenRecordingProvider) VerifyResumeToken(token string) (uuid.UUID, error) {
+	select {
+	case r.verifyCalls <- token:
+		return r.ResumeTokenProvider.VerifyResumeToken(token)
+	default:
+		r.t.Error("verifyCalls full")
+		return uuid.Nil, xerrors.New("verifyCalls full")
+	}
 }
 
 func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
@@ -546,15 +558,12 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
 	clock := quartz.NewMock(t)
 	resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
 	require.NoError(t, err)
-	resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
-	coordinator := &resumeTokenTestFakeCoordinator{
-		Coordinator: tailnet.NewCoordinator(logger),
-		t:           t,
-		peerIDCh:    make(chan uuid.UUID, 1),
-	}
-	defer close(coordinator.peerIDCh)
+	resumeTokenProvider := newResumeTokenRecordingProvider(
+		t,
+		tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour),
+	)
 	client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
-		Coordinator:                    coordinator,
+		Coordinator:                    tailnet.NewCoordinator(logger),
 		CoordinatorResumeTokenProvider: resumeTokenProvider,
 	})
 	defer closer.Close()
@@ -576,7 +585,7 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
 	// random value.
 	originalResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, "")
 	require.NoError(t, err)
-	originalPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh)
+	originalPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
 	require.NotEqual(t, originalPeerID, uuid.Nil)
 
 	// Connect with a valid resume token, and ensure that the peer ID is set to
@@ -584,7 +593,9 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
 	clock.Advance(time.Second)
 	newResumeToken, err := connectToCoordinatorAndFetchResumeToken(ctx, logger, client, agentAndBuild.WorkspaceAgent.ID, originalResumeToken)
 	require.NoError(t, err)
-	newPeerID := testutil.RequireRecvCtx(ctx, t, coordinator.peerIDCh)
+	verifiedToken := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
+	require.Equal(t, originalResumeToken, verifiedToken)
+	newPeerID := testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.generateCalls)
 	require.Equal(t, originalPeerID, newPeerID)
 	require.NotEqual(t, originalResumeToken, newResumeToken)
 
@@ -598,9 +609,11 @@ func TestWorkspaceAgentClientCoordinate_ResumeToken(t *testing.T) {
 	require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
 	require.Len(t, sdkErr.Validations, 1)
 	require.Equal(t, "resume_token", sdkErr.Validations[0].Field)
+	verifiedToken = testutil.RequireRecvCtx(ctx, t, resumeTokenProvider.verifyCalls)
+	require.Equal(t, "invalid", verifiedToken)
 
 	select {
-	case <-coordinator.peerIDCh:
+	case <-resumeTokenProvider.generateCalls:
 		t.Fatal("unexpected peer ID in channel")
 	default:
 	}
@@ -646,21 +659,6 @@ func connectToCoordinatorAndFetchResumeToken(ctx context.Context, logger slog.Lo
 		return "", xerrors.Errorf("new dRPC client: %w", err)
 	}
 
-	// Send an empty coordination request. This will do nothing on the server,
-	// but ensures our wrapped coordinator can record the peer ID.
-	coordinateClient, err := rpcClient.Coordinate(ctx)
-	if err != nil {
-		return "", xerrors.Errorf("coordinate: %w", err)
-	}
-	err = coordinateClient.Send(&tailnetproto.CoordinateRequest{})
-	if err != nil {
-		return "", xerrors.Errorf("send empty coordination request: %w", err)
-	}
-	err = coordinateClient.Close()
-	if err != nil {
-		return "", xerrors.Errorf("close coordination request: %w", err)
-	}
-
 	// Fetch a resume token.
 	newResumeToken, err := rpcClient.RefreshResumeToken(ctx, &tailnetproto.RefreshResumeTokenRequest{})
 	if err != nil {
@@ -119,6 +119,8 @@ func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID
 		return xerrors.Errorf("yamux init failed: %w", err)
 	}
 	ctx = WithStreamID(ctx, streamID)
+	s.Logger.Debug(ctx, "serving dRPC tailnet v2 API session",
+		slog.F("peer_id", streamID.ID.String()))
 	return s.drpc.Serve(ctx, session)
 }
 