Files
coder/codersdk/workspacesdk/dialer_test.go
2025-04-16 10:37:09 -06:00

435 lines
12 KiB
Go

package workspacesdk_test
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
func TestWebsocketDialer_TokenController(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
fTokenProv := newFakeTokenController(ctx, t)
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
})
require.NoError(t, err)
dialTokens := make(chan string, 1)
wsErr := make(chan error, 1)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-ctx.Done():
t.Error("timed out sending token")
case dialTokens <- r.URL.Query().Get("resume_token"):
// OK
}
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
// streamID can be empty because we don't call RPCs in this test.
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
}))
defer svr.Close()
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
clientCh := make(chan tailnet.ControlProtocolClients, 1)
go func() {
clients, err := uut.Dial(ctx, fTokenProv)
assert.NoError(t, err)
clientCh <- clients
}()
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
call <- tokenResponse{"test token", true}
gotToken := <-dialTokens
require.Equal(t, "test token", gotToken)
clients := testutil.TryReceive(ctx, t, clientCh)
clients.Closer.Close()
err = testutil.TryReceive(ctx, t, wsErr)
require.NoError(t, err)
clientCh = make(chan tailnet.ControlProtocolClients, 1)
go func() {
clients, err := uut.Dial(ctx, fTokenProv)
assert.NoError(t, err)
clientCh <- clients
}()
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
call <- tokenResponse{"test token", false}
gotToken = <-dialTokens
require.Equal(t, "", gotToken)
clients = testutil.TryReceive(ctx, t, clientCh)
require.Nil(t, clients.WorkspaceUpdates)
clients.Closer.Close()
err = testutil.TryReceive(ctx, t, wsErr)
require.NoError(t, err)
}
func TestWebsocketDialer_NoTokenController(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
})
require.NoError(t, err)
dialTokens := make(chan string, 1)
wsErr := make(chan error, 1)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-ctx.Done():
t.Error("timed out sending token")
case dialTokens <- r.URL.Query().Get("resume_token"):
// OK
}
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
// streamID can be empty because we don't call RPCs in this test.
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
}))
defer svr.Close()
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
clientCh := make(chan tailnet.ControlProtocolClients, 1)
go func() {
clients, err := uut.Dial(ctx, nil)
assert.NoError(t, err)
clientCh <- clients
}()
gotToken := <-dialTokens
require.Equal(t, "", gotToken)
clients := testutil.TryReceive(ctx, t, clientCh)
clients.Closer.Close()
err = testutil.TryReceive(ctx, t, wsErr)
require.NoError(t, err)
}
func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
fTokenProv := newFakeTokenController(ctx, t)
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
})
require.NoError(t, err)
dialTokens := make(chan string, 1)
wsErr := make(chan error, 1)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resumeToken := r.URL.Query().Get("resume_token")
select {
case <-ctx.Done():
t.Error("timed out sending token")
case dialTokens <- resumeToken:
// OK
}
if resumeToken != "" {
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
Validations: []codersdk.ValidationError{
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
},
})
return
}
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
// streamID can be empty because we don't call RPCs in this test.
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
}))
defer svr.Close()
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
errCh := make(chan error, 1)
go func() {
_, err := uut.Dial(ctx, fTokenProv)
errCh <- err
}()
call := testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
call <- tokenResponse{"test token", true}
gotToken := <-dialTokens
require.Equal(t, "test token", gotToken)
err = testutil.TryReceive(ctx, t, errCh)
require.Error(t, err)
// redial should not use the token
clientCh := make(chan tailnet.ControlProtocolClients, 1)
go func() {
clients, err := uut.Dial(ctx, fTokenProv)
assert.NoError(t, err)
clientCh <- clients
}()
gotToken = <-dialTokens
require.Equal(t, "", gotToken)
clients := testutil.TryReceive(ctx, t, clientCh)
require.Error(t, err)
clients.Closer.Close()
err = testutil.TryReceive(ctx, t, wsErr)
require.NoError(t, err)
// Successful dial should reset to using token again
go func() {
_, err := uut.Dial(ctx, fTokenProv)
errCh <- err
}()
call = testutil.TryReceive(ctx, t, fTokenProv.tokenCalls)
call <- tokenResponse{"test token", true}
gotToken = <-dialTokens
require.Equal(t, "test token", gotToken)
err = testutil.TryReceive(ctx, t, errCh)
require.Error(t, err)
}
func TestWebsocketDialer_UplevelVersion(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sVer := apiversion.New(2, 2)
// the following matches what Coderd does;
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
cVer := r.URL.Query().Get("version")
if err := sVer.Validate(cVer); err != nil {
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
Message: workspacesdk.AgentAPIMismatchMessage,
Validations: []codersdk.ValidationError{
{Field: "version", Detail: err.Error()},
},
})
return
}
}))
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
uut := workspacesdk.NewWebsocketDialer(
logger, svrURL, &websocket.DialOptions{},
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{}),
)
errCh := make(chan error, 1)
go func() {
_, err := uut.Dial(ctx, nil)
errCh <- err
}()
err = testutil.TryReceive(ctx, t, errCh)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message)
require.NotEmpty(t, sdkErr.Helper)
}
func TestWebsocketDialer_WorkspaceUpdates(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
fCoord := tailnettest.NewFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
ctrl := gomock.NewController(t)
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
WorkspaceUpdatesProvider: mProvider,
})
require.NoError(t, err)
wsErr := make(chan error, 1)
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// need 2.3 for WorkspaceUpdates RPC
cVer := r.URL.Query().Get("version")
assert.Equal(t, "2.3", cVer)
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
// streamID can be empty because we don't call RPCs in this test.
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
}))
defer svr.Close()
svrURL, err := url.Parse(svr.URL)
require.NoError(t, err)
userID := uuid.UUID{88}
mSub := tailnettest.NewMockSubscription(ctrl)
updateCh := make(chan *tailnetproto.WorkspaceUpdate, 1)
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
mSub.EXPECT().Updates().MinTimes(1).Return(updateCh)
mSub.EXPECT().Close().Times(1).Return(nil)
uut := workspacesdk.NewWebsocketDialer(
logger, svrURL, &websocket.DialOptions{},
workspacesdk.WithWorkspaceUpdates(&tailnetproto.WorkspaceUpdatesRequest{
WorkspaceOwnerId: userID[:],
}),
)
clients, err := uut.Dial(ctx, nil)
require.NoError(t, err)
require.NotNil(t, clients.WorkspaceUpdates)
wsID := uuid.UUID{99}
expectedUpdate := &tailnetproto.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnetproto.Workspace{
{Id: wsID[:]},
},
}
updateCh <- expectedUpdate
gotUpdate, err := clients.WorkspaceUpdates.Recv()
require.NoError(t, err)
require.Equal(t, wsID[:], gotUpdate.GetUpsertedWorkspaces()[0].GetId())
clients.Closer.Close()
err = testutil.TryReceive(ctx, t, wsErr)
require.NoError(t, err)
}
type fakeResumeTokenController struct {
ctx context.Context
t testing.TB
tokenCalls chan chan tokenResponse
}
func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter {
panic("not implemented")
}
func (f *fakeResumeTokenController) Token() (string, bool) {
call := make(chan tokenResponse)
select {
case <-f.ctx.Done():
f.t.Error("timeout on Token() call")
case f.tokenCalls <- call:
// OK
}
select {
case <-f.ctx.Done():
f.t.Error("timeout on Token() response")
return "", false
case r := <-call:
return r.token, r.ok
}
}
var _ tailnet.ResumeTokenController = &fakeResumeTokenController{}
func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController {
return &fakeResumeTokenController{
ctx: ctx,
t: t,
tokenCalls: make(chan chan tokenResponse),
}
}
type tokenResponse struct {
token string
ok bool
}