mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
651 lines
21 KiB
Go
651 lines
21 KiB
Go
package workspacesdk
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/hashicorp/yamux"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/protobuf/types/known/durationpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"nhooyr.io/websocket"
|
|
"storj.io/drpc"
|
|
"storj.io/drpc/drpcerr"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/apiversion"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func init() {
|
|
// Give tests a bit more time to timeout. Darwin is particularly slow.
|
|
tailnetConnectorGracefulTimeout = 5 * time.Second
|
|
}
|
|
|
|
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|
t.Parallel()
|
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
|
ctx, cancel := context.WithCancel(testCtx)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs,
|
|
io.EOF, // we get EOF when we simulate a DERPMap error
|
|
yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session
|
|
),
|
|
}).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
clientID := uuid.UUID{0x66}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
|
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
})
|
|
assert.NoError(t, err)
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
|
uut.runConnector(fConn)
|
|
|
|
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
|
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.NotNil(t, reqTun.AddTunnel)
|
|
|
|
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
|
|
|
|
// simulate a problem with DERPMaps by sending nil
|
|
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
|
|
|
|
// this should cause the coordinate call to hang up WITHOUT disconnecting
|
|
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.Nil(t, reqNil)
|
|
|
|
// ...and then reconnect
|
|
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
|
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.NotNil(t, reqTun.AddTunnel)
|
|
|
|
// canceling the context should trigger the disconnect message
|
|
cancel()
|
|
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
|
require.NotNil(t, reqDisc)
|
|
require.NotNil(t, reqDisc.Disconnect)
|
|
}
|
|
|
|
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
|
|
|
|
// the following matches what Coderd does;
|
|
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
|
cVer := r.URL.Query().Get("version")
|
|
if err := sVer.Validate(cVer); err != nil {
|
|
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
|
Message: AgentAPIMismatchMessage,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "version", Detail: err.Error()},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
|
uut.runConnector(fConn)
|
|
|
|
err := testutil.RequireRecvCtx(ctx, t, uut.connected)
|
|
var sdkErr *codersdk.Error
|
|
require.ErrorAs(t, err, &sdkErr)
|
|
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
|
require.Equal(t, AgentAPIMismatchMessage, sdkErr.Message)
|
|
require.NotEmpty(t, sdkErr.Helper)
|
|
}
|
|
|
|
func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
|
|
clock := quartz.NewMock(t)
|
|
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
|
require.NoError(t, err)
|
|
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
|
ResumeTokenProvider: resumeTokenProvider,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
websocketConnCh = make(chan *websocket.Conn, 64)
|
|
expectResumeToken = ""
|
|
)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Accept a resume_token query parameter to use the same peer ID. This
|
|
// behavior matches the actual client coordinate route.
|
|
var (
|
|
peerID = uuid.New()
|
|
resumeToken = r.URL.Query().Get("resume_token")
|
|
)
|
|
t.Logf("received resume token: %s", resumeToken)
|
|
assert.Equal(t, expectResumeToken, resumeToken)
|
|
if resumeToken != "" {
|
|
peerID, err = resumeTokenProvider.VerifyResumeToken(resumeToken)
|
|
assert.NoError(t, err, "failed to parse resume token")
|
|
if err != nil {
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
|
Message: CoordinateAPIInvalidResumeToken,
|
|
Detail: err.Error(),
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
|
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
|
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: peerID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
})
|
|
assert.NoError(t, err)
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
|
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
|
defer newTickerTrap.Close()
|
|
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
|
uut.runConnector(fConn)
|
|
|
|
// Fetch first token. We don't need to advance the clock since we use a
|
|
// channel with a single item to immediately fetch.
|
|
newTickerTrap.MustWait(ctx).Release()
|
|
// We call ticker.Reset after each token fetch to apply the refresh duration
|
|
// requested by the server.
|
|
trappedReset := tickerResetTrap.MustWait(ctx)
|
|
trappedReset.Release()
|
|
require.NotNil(t, uut.resumeToken)
|
|
originalResumeToken := uut.resumeToken.Token
|
|
|
|
// Fetch second token.
|
|
waiter := clock.Advance(trappedReset.Duration)
|
|
waiter.MustWait(ctx)
|
|
trappedReset = tickerResetTrap.MustWait(ctx)
|
|
trappedReset.Release()
|
|
require.NotNil(t, uut.resumeToken)
|
|
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
|
expectResumeToken = uut.resumeToken.Token
|
|
t.Logf("expecting resume token: %s", expectResumeToken)
|
|
|
|
// Sever the connection and expect it to reconnect with the resume token.
|
|
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
|
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
|
|
|
// Wait for the resume token to be refreshed.
|
|
trappedTicker := newTickerTrap.MustWait(ctx)
|
|
// Advance the clock slightly to ensure the new JWT is different.
|
|
clock.Advance(time.Second).MustWait(ctx)
|
|
trappedTicker.Release()
|
|
trappedReset = tickerResetTrap.MustWait(ctx)
|
|
trappedReset.Release()
|
|
|
|
// The resume token should have changed again.
|
|
require.NotNil(t, uut.resumeToken)
|
|
require.NotEqual(t, expectResumeToken, uut.resumeToken.Token)
|
|
}
|
|
|
|
func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
IgnoreErrors: true,
|
|
}).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
|
|
clock := quartz.NewMock(t)
|
|
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
|
require.NoError(t, err)
|
|
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(resumeTokenSigningKey, clock, time.Hour)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {},
|
|
ResumeTokenProvider: resumeTokenProvider,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
var (
|
|
websocketConnCh = make(chan *websocket.Conn, 64)
|
|
didFail int64
|
|
)
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("resume_token") != "" {
|
|
atomic.AddInt64(&didFail, 1)
|
|
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
|
Message: CoordinateAPIInvalidResumeToken,
|
|
Validations: []codersdk.ValidationError{
|
|
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
|
},
|
|
})
|
|
return
|
|
}
|
|
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
|
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
|
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: uuid.New(),
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
})
|
|
assert.NoError(t, err)
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
|
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
|
defer newTickerTrap.Close()
|
|
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
|
uut.runConnector(fConn)
|
|
|
|
// Wait for the resume token to be fetched for the first time.
|
|
newTickerTrap.MustWait(ctx).Release()
|
|
trappedReset := tickerResetTrap.MustWait(ctx)
|
|
trappedReset.Release()
|
|
originalResumeToken := uut.resumeToken.Token
|
|
|
|
// Sever the connection and expect it to reconnect with the resume token,
|
|
// which should fail and cause the client to be disconnected. The client
|
|
// should then reconnect with no resume token.
|
|
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
|
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
|
|
|
// Wait for the resume token to be refreshed, which indicates a successful
|
|
// reconnect.
|
|
trappedTicker := newTickerTrap.MustWait(ctx)
|
|
// Since we failed the initial reconnect and we're definitely reconnected
|
|
// now, the stored resume token should now be nil.
|
|
require.Nil(t, uut.resumeToken)
|
|
trappedTicker.Release()
|
|
trappedReset = tickerResetTrap.MustWait(ctx)
|
|
trappedReset.Release()
|
|
require.NotNil(t, uut.resumeToken)
|
|
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
|
|
|
// The resume token should have been rejected by the server.
|
|
require.EqualValues(t, 1, atomic.LoadInt64(&didFail))
|
|
}
|
|
|
|
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
clientID := uuid.UUID{0x66}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
eventCh := make(chan []*proto.TelemetryEvent, 1)
|
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
|
Logger: logger,
|
|
CoordPtr: &coordPtr,
|
|
DERPMapUpdateFrequency: time.Millisecond,
|
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
|
testutil.RequireSendCtx(ctx, t, eventCh, batch)
|
|
},
|
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
|
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
})
|
|
assert.NoError(t, err)
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
|
uut.runConnector(fConn)
|
|
require.Eventually(t, func() bool {
|
|
uut.clientMu.Lock()
|
|
defer uut.clientMu.Unlock()
|
|
return uut.client != nil
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{
|
|
Id: []byte("test event"),
|
|
})
|
|
|
|
testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)
|
|
|
|
require.Len(t, testEvents, 1)
|
|
require.Equal(t, []byte("test event"), testEvents[0].Id)
|
|
}
|
|
|
|
func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
fConn := newFakeTailnetConn()
|
|
|
|
fakeDRPCClient := newFakeDRPCClient()
|
|
uut := &tailnetAPIConnector{
|
|
ctx: ctx,
|
|
logger: logger,
|
|
agentID: agentID,
|
|
coordinateURL: "",
|
|
clock: quartz.NewReal(),
|
|
dialOptions: &websocket.DialOptions{},
|
|
conn: nil,
|
|
connected: make(chan error, 1),
|
|
closed: make(chan struct{}),
|
|
customDialFn: func() (proto.DRPCTailnetClient, error) {
|
|
return fakeDRPCClient, nil
|
|
},
|
|
}
|
|
uut.runConnector(fConn)
|
|
require.Eventually(t, func() bool {
|
|
uut.clientMu.Lock()
|
|
defer uut.clientMu.Unlock()
|
|
return uut.client != nil
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), 0)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.False(t, uut.telemetryUnavailable.Load())
|
|
require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
|
|
|
|
fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), drpcerr.Unimplemented)
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.True(t, uut.telemetryUnavailable.Load())
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
|
|
}
|
|
|
|
func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
fConn := newFakeTailnetConn()
|
|
|
|
fakeDRPCClient := newFakeDRPCClient()
|
|
uut := &tailnetAPIConnector{
|
|
ctx: ctx,
|
|
logger: logger,
|
|
agentID: agentID,
|
|
coordinateURL: "",
|
|
clock: quartz.NewReal(),
|
|
dialOptions: &websocket.DialOptions{},
|
|
conn: nil,
|
|
connected: make(chan error, 1),
|
|
closed: make(chan struct{}),
|
|
customDialFn: func() (proto.DRPCTailnetClient, error) {
|
|
return fakeDRPCClient, nil
|
|
},
|
|
}
|
|
uut.runConnector(fConn)
|
|
require.Eventually(t, func() bool {
|
|
uut.clientMu.Lock()
|
|
defer uut.clientMu.Unlock()
|
|
return uut.client != nil
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
|
|
fakeDRPCClient.telemetryError = drpc.ProtocolError.New("Protocol Error")
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.False(t, uut.telemetryUnavailable.Load())
|
|
require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
|
|
|
|
fakeDRPCClient.telemetryError = drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry")
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.True(t, uut.telemetryUnavailable.Load())
|
|
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
|
|
require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
|
|
}
|
|
|
|
type fakeTailnetConn struct{}
|
|
|
|
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
|
// TODO implement me
|
|
panic("implement me")
|
|
}
|
|
|
|
func (*fakeTailnetConn) SetAllPeersLost() {}
|
|
|
|
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
|
|
|
|
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
|
|
|
|
func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}
|
|
|
|
func newFakeTailnetConn() *fakeTailnetConn {
|
|
return &fakeTailnetConn{}
|
|
}
|
|
|
|
type fakeDRPCClient struct {
|
|
postTelemetryCalls int64
|
|
refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
|
|
telemetryError error
|
|
fakeDRPPCMapStream
|
|
}
|
|
|
|
var _ proto.DRPCTailnetClient = &fakeDRPCClient{}
|
|
|
|
func newFakeDRPCClient() *fakeDRPCClient {
|
|
return &fakeDRPCClient{
|
|
postTelemetryCalls: 0,
|
|
fakeDRPPCMapStream: fakeDRPPCMapStream{
|
|
fakeDRPCStream: fakeDRPCStream{
|
|
ch: make(chan struct{}),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// Coordinate implements proto.DRPCTailnetClient.
|
|
func (f *fakeDRPCClient) Coordinate(_ context.Context) (proto.DRPCTailnet_CoordinateClient, error) {
|
|
return &f.fakeDRPCStream, nil
|
|
}
|
|
|
|
// DRPCConn implements proto.DRPCTailnetClient.
|
|
func (*fakeDRPCClient) DRPCConn() drpc.Conn {
|
|
return &fakeDRPCConn{}
|
|
}
|
|
|
|
// PostTelemetry implements proto.DRPCTailnetClient.
|
|
func (f *fakeDRPCClient) PostTelemetry(_ context.Context, _ *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
|
|
atomic.AddInt64(&f.postTelemetryCalls, 1)
|
|
return nil, f.telemetryError
|
|
}
|
|
|
|
// StreamDERPMaps implements proto.DRPCTailnetClient.
|
|
func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMapsRequest) (proto.DRPCTailnet_StreamDERPMapsClient, error) {
|
|
return &f.fakeDRPPCMapStream, nil
|
|
}
|
|
|
|
// RefreshResumeToken implements proto.DRPCTailnetClient.
|
|
func (f *fakeDRPCClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
|
|
if f.refreshTokenFn != nil {
|
|
return f.refreshTokenFn(context.Background(), nil)
|
|
}
|
|
|
|
return &proto.RefreshResumeTokenResponse{
|
|
Token: "test",
|
|
RefreshIn: durationpb.New(30 * time.Minute),
|
|
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
|
}, nil
|
|
}
|
|
|
|
type fakeDRPCConn struct{}
|
|
|
|
var _ drpc.Conn = &fakeDRPCConn{}
|
|
|
|
// Close implements drpc.Conn.
|
|
func (*fakeDRPCConn) Close() error {
|
|
return nil
|
|
}
|
|
|
|
// Closed implements drpc.Conn.
|
|
func (*fakeDRPCConn) Closed() <-chan struct{} {
|
|
return nil
|
|
}
|
|
|
|
// Invoke implements drpc.Conn.
|
|
func (*fakeDRPCConn) Invoke(_ context.Context, _ string, _ drpc.Encoding, _ drpc.Message, _ drpc.Message) error {
|
|
return nil
|
|
}
|
|
|
|
// NewStream implements drpc.Conn.
|
|
func (*fakeDRPCConn) NewStream(_ context.Context, _ string, _ drpc.Encoding) (drpc.Stream, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
type fakeDRPCStream struct {
|
|
ch chan struct{}
|
|
}
|
|
|
|
var _ proto.DRPCTailnet_CoordinateClient = &fakeDRPCStream{}
|
|
|
|
// Close implements proto.DRPCTailnet_CoordinateClient.
|
|
func (f *fakeDRPCStream) Close() error {
|
|
close(f.ch)
|
|
return nil
|
|
}
|
|
|
|
// CloseSend implements proto.DRPCTailnet_CoordinateClient.
|
|
func (*fakeDRPCStream) CloseSend() error {
|
|
return nil
|
|
}
|
|
|
|
// Context implements proto.DRPCTailnet_CoordinateClient.
|
|
func (*fakeDRPCStream) Context() context.Context {
|
|
return nil
|
|
}
|
|
|
|
// MsgRecv implements proto.DRPCTailnet_CoordinateClient.
|
|
func (*fakeDRPCStream) MsgRecv(_ drpc.Message, _ drpc.Encoding) error {
|
|
return nil
|
|
}
|
|
|
|
// MsgSend implements proto.DRPCTailnet_CoordinateClient.
|
|
func (*fakeDRPCStream) MsgSend(_ drpc.Message, _ drpc.Encoding) error {
|
|
return nil
|
|
}
|
|
|
|
// Recv implements proto.DRPCTailnet_CoordinateClient.
|
|
func (f *fakeDRPCStream) Recv() (*proto.CoordinateResponse, error) {
|
|
<-f.ch
|
|
return &proto.CoordinateResponse{}, nil
|
|
}
|
|
|
|
// Send implements proto.DRPCTailnet_CoordinateClient.
|
|
func (f *fakeDRPCStream) Send(*proto.CoordinateRequest) error {
|
|
<-f.ch
|
|
return nil
|
|
}
|
|
|
|
type fakeDRPPCMapStream struct {
|
|
fakeDRPCStream
|
|
}
|
|
|
|
var _ proto.DRPCTailnet_StreamDERPMapsClient = &fakeDRPPCMapStream{}
|
|
|
|
// Recv implements proto.DRPCTailnet_StreamDERPMapsClient.
|
|
func (f *fakeDRPPCMapStream) Recv() (*proto.DERPMap, error) {
|
|
<-f.fakeDRPCStream.ch
|
|
return &proto.DERPMap{}, nil
|
|
}
|