package workspacesdk_test import ( "context" "net/http" "net/http/httptest" "net/url" "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "tailscale.com/tailcfg" "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/apiversion" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/tailnet" tailnetproto "github.com/coder/coder/v2/tailnet/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/coder/v2/testutil" "github.com/coder/websocket" ) func TestWebsocketDialer_TokenController(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitMedium) logger := slogtest.Make(t, &slogtest.Options{ IgnoreErrors: true, }).Leveled(slog.LevelDebug) fTokenProv := newFakeTokenController(ctx, t) fCoord := tailnettest.NewFakeCoordinator() var coord tailnet.Coordinator = fCoord coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Hour, DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, }) require.NoError(t, err) dialTokens := make(chan string, 1) wsErr := make(chan error, 1) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-ctx.Done(): t.Error("timed out sending token") case dialTokens <- r.URL.Query().Get("resume_token"): // OK } sws, err := websocket.Accept(w, r, nil) if !assert.NoError(t, err) { return } wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) // streamID can be empty because we don't call RPCs in this test. wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) })) defer svr.Close() svrURL, err := url.Parse(svr.URL) require.NoError(t, err) uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) clientCh := make(chan tailnet.ControlProtocolClients, 1) go func() { clients, err := uut.Dial(ctx, fTokenProv) assert.NoError(t, err) clientCh <- clients }() call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) call <- tokenResponse{"test token", true} gotToken := <-dialTokens require.Equal(t, "test token", gotToken) clients := testutil.RequireRecvCtx(ctx, t, clientCh) clients.Closer.Close() err = testutil.RequireRecvCtx(ctx, t, wsErr) require.NoError(t, err) clientCh = make(chan tailnet.ControlProtocolClients, 1) go func() { clients, err := uut.Dial(ctx, fTokenProv) assert.NoError(t, err) clientCh <- clients }() call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) call <- tokenResponse{"test token", false} gotToken = <-dialTokens require.Equal(t, "", gotToken) clients = testutil.RequireRecvCtx(ctx, t, clientCh) require.Nil(t, clients.WorkspaceUpdates) clients.Closer.Close() err = testutil.RequireRecvCtx(ctx, t, wsErr) require.NoError(t, err) } func TestWebsocketDialer_NoTokenController(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{ IgnoreErrors: true, }).Leveled(slog.LevelDebug) fCoord := tailnettest.NewFakeCoordinator() var coord tailnet.Coordinator = fCoord coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Hour, DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, }) require.NoError(t, err) dialTokens := make(chan string, 1) wsErr := make(chan error, 1) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-ctx.Done(): t.Error("timed out sending token") case dialTokens <- r.URL.Query().Get("resume_token"): // OK } sws, err := websocket.Accept(w, r, nil) if !assert.NoError(t, err) { return } wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) // streamID can be empty because we don't call RPCs in this test. wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) })) defer svr.Close() svrURL, err := url.Parse(svr.URL) require.NoError(t, err) uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) clientCh := make(chan tailnet.ControlProtocolClients, 1) go func() { clients, err := uut.Dial(ctx, nil) assert.NoError(t, err) clientCh <- clients }() gotToken := <-dialTokens require.Equal(t, "", gotToken) clients := testutil.RequireRecvCtx(ctx, t, clientCh) clients.Closer.Close() err = testutil.RequireRecvCtx(ctx, t, wsErr) require.NoError(t, err) } func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{ IgnoreErrors: true, }).Leveled(slog.LevelDebug) fTokenProv := newFakeTokenController(ctx, t) fCoord := tailnettest.NewFakeCoordinator() var coord tailnet.Coordinator = fCoord coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Hour, DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, }) require.NoError(t, err) dialTokens := make(chan string, 1) wsErr := make(chan error, 1) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resumeToken := r.URL.Query().Get("resume_token") select { case <-ctx.Done(): t.Error("timed out sending token") case dialTokens <- resumeToken: // OK } if resumeToken != "" { httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{ Message: workspacesdk.CoordinateAPIInvalidResumeToken, Validations: []codersdk.ValidationError{ {Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken}, }, }) return } sws, err := websocket.Accept(w, r, nil) if !assert.NoError(t, err) { return } wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) // streamID can be empty because we don't call RPCs in this test. wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) })) defer svr.Close() svrURL, err := url.Parse(svr.URL) require.NoError(t, err) uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{}) errCh := make(chan error, 1) go func() { _, err := uut.Dial(ctx, fTokenProv) errCh <- err }() call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) call <- tokenResponse{"test token", true} gotToken := <-dialTokens require.Equal(t, "test token", gotToken) err = testutil.RequireRecvCtx(ctx, t, errCh) require.Error(t, err) // redial should not use the token clientCh := make(chan tailnet.ControlProtocolClients, 1) go func() { clients, err := uut.Dial(ctx, fTokenProv) assert.NoError(t, err) clientCh <- clients }() gotToken = <-dialTokens require.Equal(t, "", gotToken) clients := testutil.RequireRecvCtx(ctx, t, clientCh) require.Error(t, err) clients.Closer.Close() err = testutil.RequireRecvCtx(ctx, t, wsErr) require.NoError(t, err) // Successful dial should reset to using token again go func() { _, err := uut.Dial(ctx, fTokenProv) errCh <- err }() call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls) call <- tokenResponse{"test token", true} gotToken = <-dialTokens require.Equal(t, "test token", gotToken) err = testutil.RequireRecvCtx(ctx, t, errCh) require.Error(t, err) } func TestWebsocketDialer_UplevelVersion(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sVer := apiversion.New(2, 2) // the following matches what Coderd does; // c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate cVer := r.URL.Query().Get("version") if err := sVer.Validate(cVer); err != nil { httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{ Message: workspacesdk.AgentAPIMismatchMessage, Validations: []codersdk.ValidationError{ {Field: "version", Detail: err.Error()}, }, }) return } })) svrURL, err := url.Parse(svr.URL) require.NoError(t, err) uut := workspacesdk.NewWebsocketDialer( logger, svrURL, &websocket.DialOptions{}, workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}), ) errCh := make(chan error, 1) go func() { _, err := uut.Dial(ctx, nil) errCh <- err }() err = testutil.RequireRecvCtx(ctx, t, errCh) var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message) require.NotEmpty(t, sdkErr.Helper) } func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) logger := slogtest.Make(t, &slogtest.Options{ IgnoreErrors: true, }).Leveled(slog.LevelDebug) fCoord := tailnettest.NewFakeCoordinator() var coord tailnet.Coordinator = fCoord coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) ctrl := gomock.NewController(t) mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl) svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{ Logger: logger, CoordPtr: &coordPtr, DERPMapUpdateFrequency: time.Hour, DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} }, WorkspaceUpdatesProvider: mProvider, }) require.NoError(t, err) wsErr := make(chan error, 1) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // need 2.3 for WorkspaceUpdates RPC cVer := r.URL.Query().Get("version") assert.Equal(t, "2.3", cVer) sws, err := websocket.Accept(w, r, nil) if !assert.NoError(t, err) { return } wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary) // streamID can be empty because we don't call RPCs in this test. wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{}) })) defer svr.Close() svrURL, err := url.Parse(svr.URL) require.NoError(t, err) userID := uuid.UUID{88} mSub := tailnettest.NewMockSubscription(ctrl) updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1) mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil) mSub.EXPECT().Updates().MinTimes(1).Return(updateCh) mSub.EXPECT().Close().Times(1).Return(nil) uut := workspacesdk.NewWebsocketDialer( logger, svrURL, &websocket.DialOptions{}, workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{ WorkspaceOwnerId: userID[:], }), ) clients, err := uut.Dial(ctx, nil) require.NoError(t, err) require.NotNil(t, clients.WorkspaceUpdates) wsID := uuid.UUID{99} expectedUpdate := &tailnetproto.WorkspaceUpdate{ UpsertedWorkspaces: []*tailnetproto.Workspace{ {Id: wsID[:]}, }, } updateCh <- expectedUpdate gotUpdate, err := clients.WorkspaceUpdates.Recv() require.NoError(t, err) require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId()) clients.Closer.Close() err = testutil.RequireRecvCtx(ctx, t, wsErr) require.NoError(t, err) } type fakeResumeTokenController struct { ctx context.Context t testing.TB tokenCalls chan chan tokenResponse } func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter { panic("not implemented") } func (f *fakeResumeTokenController) Token() (string, bool) { call := make(chan tokenResponse) select { case <-f.ctx.Done(): f.t.Error("timeout on Token() call") case f.tokenCalls <- call: // OK } select { case <-f.ctx.Done(): f.t.Error("timeout on Token() response") return "", false case r := <-call: return r.token, r.ok } } var _ tailnet.ResumeTokenController = &fakeResumeTokenController{} func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController { return &fakeResumeTokenController{ ctx: ctx, t: t, tokenCalls: make(chan chan tokenResponse), } } type tokenResponse struct { token string ok bool }