From 335e4ab6bf87b3751ad48029f00af86b3f3669e9 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Wed, 6 Nov 2024 20:23:23 +0400 Subject: [PATCH] chore: refactor sending telemetry (#15345) Implements a tailnet API Telemetry controller by refactoring from `workspacesdk`. chore re: #14729 --- codersdk/workspacesdk/connector.go | 42 +---- .../workspacesdk/connector_internal_test.go | 159 ++---------------- tailnet/controllers.go | 83 +++++++++ tailnet/controllers_test.go | 156 +++++++++++++++++ 4 files changed, 254 insertions(+), 186 deletions(-) diff --git a/codersdk/workspacesdk/connector.go b/codersdk/workspacesdk/connector.go index dd4c9a0716..c50c2b0124 100644 --- a/codersdk/workspacesdk/connector.go +++ b/codersdk/workspacesdk/connector.go @@ -8,16 +8,12 @@ import ( "net/http" "net/url" "slices" - "strings" "sync" - "sync/atomic" "time" "github.com/google/uuid" "golang.org/x/xerrors" "nhooyr.io/websocket" - "storj.io/drpc" - "storj.io/drpc/drpcerr" "cdr.dev/slog" "github.com/coder/coder/v2/buildinfo" @@ -66,19 +62,12 @@ type tailnetAPIConnector struct { dialOptions *websocket.DialOptions derpCtrl tailnet.DERPController coordCtrl tailnet.CoordinationController - customDialFn func() (proto.DRPCTailnetClient, error) - - clientMu sync.RWMutex - client proto.DRPCTailnetClient + telCtrl *tailnet.BasicTelemetryController connected chan error resumeToken *proto.RefreshResumeTokenResponse isFirst bool closed chan struct{} - - // Only set to true if we get a response from the server that it doesn't support - // network telemetry. - telemetryUnavailable atomic.Bool } // Create a new tailnetAPIConnector without running it @@ -92,6 +81,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui dialOptions: dialOptions, connected: make(chan error, 1), closed: make(chan struct{}), + telCtrl: tailnet.NewBasicTelemetryController(logger), } } @@ -124,9 +114,6 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) { if err != nil { continue } - tac.clientMu.Lock() - tac.client = tailnetClient - tac.clientMu.Unlock() tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client") tac.runConnectorOnce(tailnetClient) tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost") @@ -141,9 +128,6 @@ var permanentErrorStatuses = []int{ } func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) { - if tac.customDialFn != nil { - return tac.customDialFn() - } tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API") u, err := url.Parse(tac.coordinateURL) @@ -228,6 +212,8 @@ func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) } }() + tac.telCtrl.New(client) // synchronous, doesn't need a goroutine + refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx) wg := sync.WaitGroup{} wg.Add(3) @@ -245,10 +231,7 @@ func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) // we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just // close the underlying connection. This will trigger a retry of the control plane in // run(). - tac.clientMu.Lock() client.DRPCConn().Close() - tac.client = nil - tac.clientMu.Unlock() // Note that derpMap() logs it own errors, we don't bother here. } }() @@ -351,20 +334,5 @@ func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.D } func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) { - tac.clientMu.RLock() - // We hold the lock for the entire telemetry request, but this would only block - // a coordinate retry, and closing the connection. - defer tac.clientMu.RUnlock() - if tac.client == nil || tac.telemetryUnavailable.Load() { - return - } - ctx, cancel := context.WithTimeout(tac.ctx, 5*time.Second) - defer cancel() - _, err := tac.client.PostTelemetry(ctx, &proto.TelemetryRequest{ - Events: []*proto.TelemetryEvent{event}, - }) - if drpcerr.Code(err) == drpcerr.Unimplemented || drpc.ProtocolError.Has(err) && strings.Contains(err.Error(), "unknown rpc: ") { - tac.logger.Debug(tac.ctx, "attempted to send telemetry to a server that doesn't support it", slog.Error(err)) - tac.telemetryUnavailable.Store(true) - } + tac.telCtrl.SendTelemetryEvent(event) } diff --git a/codersdk/workspacesdk/connector_internal_test.go b/codersdk/workspacesdk/connector_internal_test.go index c3f22cd98b..88b857320c 100644 --- a/codersdk/workspacesdk/connector_internal_test.go +++ b/codersdk/workspacesdk/connector_internal_test.go @@ -13,12 +13,8 @@ import ( "github.com/hashicorp/yamux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" "nhooyr.io/websocket" "storj.io/drpc" - "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" "cdr.dev/slog" @@ -385,7 +381,12 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { DERPMapUpdateFrequency: time.Millisecond, DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh }, NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) { - testutil.RequireSendCtx(ctx, t, eventCh, batch) + select { + case <-ctx.Done(): + t.Error("timeout sending telemetry event") + case eventCh <- batch: + t.Log("sent telemetry batch") + } }, ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(), }) @@ -409,11 +410,10 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{}) uut.runConnector(fConn) - require.Eventually(t, func() bool { - uut.clientMu.Lock() - defer uut.clientMu.Unlock() - return uut.client != nil - }, testutil.WaitShort, testutil.IntervalFast) + // Coordinate calls happen _after_ telemetry is connected up, so we use this + // to ensure telemetry is connected before sending our event + cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls) + defer close(cc.Resps) uut.SendTelemetryEvent(&proto.TelemetryEvent{ Id: []byte("test event"), @@ -425,86 +425,6 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) { require.Equal(t, []byte("test event"), testEvents[0].Id) } -func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - fConn := newFakeTailnetConn() - - fakeDRPCClient := newFakeDRPCClient() - uut := &tailnetAPIConnector{ - ctx: ctx, - logger: logger, - agentID: agentID, - coordinateURL: "", - clock: quartz.NewReal(), - dialOptions: &websocket.DialOptions{}, - connected: make(chan error, 1), - closed: make(chan struct{}), - customDialFn: func() (proto.DRPCTailnetClient, error) { - return fakeDRPCClient, nil - }, - } - uut.runConnector(fConn) - require.Eventually(t, func() bool { - uut.clientMu.Lock() - defer uut.clientMu.Unlock() - return uut.client != nil - }, testutil.WaitShort, testutil.IntervalFast) - - fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), 0) - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.False(t, uut.telemetryUnavailable.Load()) - require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) - - fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), drpcerr.Unimplemented) - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.True(t, uut.telemetryUnavailable.Load()) - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) -} - -func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) { - t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - agentID := uuid.UUID{0x55} - fConn := newFakeTailnetConn() - - fakeDRPCClient := newFakeDRPCClient() - uut := &tailnetAPIConnector{ - ctx: ctx, - logger: logger, - agentID: agentID, - coordinateURL: "", - clock: quartz.NewReal(), - dialOptions: &websocket.DialOptions{}, - connected: make(chan error, 1), - closed: make(chan struct{}), - customDialFn: func() (proto.DRPCTailnetClient, error) { - return fakeDRPCClient, nil - }, - } - uut.runConnector(fConn) - require.Eventually(t, func() bool { - uut.clientMu.Lock() - defer uut.clientMu.Unlock() - return uut.client != nil - }, testutil.WaitShort, testutil.IntervalFast) - - fakeDRPCClient.telemetryError = drpc.ProtocolError.New("Protocol Error") - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.False(t, uut.telemetryUnavailable.Load()) - require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) - - fakeDRPCClient.telemetryError = drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry") - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.True(t, uut.telemetryUnavailable.Load()) - uut.SendTelemetryEvent(&proto.TelemetryEvent{}) - require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls)) -} - type fakeTailnetConn struct{} func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error { @@ -524,65 +444,6 @@ func newFakeTailnetConn() *fakeTailnetConn { return &fakeTailnetConn{} } -type fakeDRPCClient struct { - postTelemetryCalls int64 - refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) - telemetryError error - fakeDRPPCMapStream -} - -var _ proto.DRPCTailnetClient = &fakeDRPCClient{} - -func newFakeDRPCClient() *fakeDRPCClient { - return &fakeDRPCClient{ - postTelemetryCalls: 0, - fakeDRPPCMapStream: fakeDRPPCMapStream{ - fakeDRPCStream: fakeDRPCStream{ - ch: make(chan struct{}), - }, - }, - } -} - -// Coordinate implements proto.DRPCTailnetClient. -func (f *fakeDRPCClient) Coordinate(_ context.Context) (proto.DRPCTailnet_CoordinateClient, error) { - return &f.fakeDRPCStream, nil -} - -// DRPCConn implements proto.DRPCTailnetClient. -func (*fakeDRPCClient) DRPCConn() drpc.Conn { - return &fakeDRPCConn{} -} - -// PostTelemetry implements proto.DRPCTailnetClient. -func (f *fakeDRPCClient) PostTelemetry(_ context.Context, _ *proto.TelemetryRequest) (*proto.TelemetryResponse, error) { - atomic.AddInt64(&f.postTelemetryCalls, 1) - return nil, f.telemetryError -} - -// StreamDERPMaps implements proto.DRPCTailnetClient. -func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMapsRequest) (proto.DRPCTailnet_StreamDERPMapsClient, error) { - return &f.fakeDRPPCMapStream, nil -} - -// RefreshResumeToken implements proto.DRPCTailnetClient. -func (f *fakeDRPCClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) { - if f.refreshTokenFn != nil { - return f.refreshTokenFn(context.Background(), nil) - } - - return &proto.RefreshResumeTokenResponse{ - Token: "test", - RefreshIn: durationpb.New(30 * time.Minute), - ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), - }, nil -} - -// WorkspaceUpdates implements proto.DRPCTailnetClient. -func (*fakeDRPCClient) WorkspaceUpdates(context.Context, *proto.WorkspaceUpdatesRequest) (proto.DRPCTailnet_WorkspaceUpdatesClient, error) { - panic("unimplemented") -} - type fakeDRPCConn struct{} var _ drpc.Conn = &fakeDRPCConn{} diff --git a/tailnet/controllers.go b/tailnet/controllers.go index 95a483081f..3176d70129 100644 --- a/tailnet/controllers.go +++ b/tailnet/controllers.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "io" + "strings" "sync" + "time" "github.com/google/uuid" "golang.org/x/xerrors" "storj.io/drpc" + "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" "cdr.dev/slog" @@ -440,3 +443,83 @@ func (l *derpSetLoop) recvLoop() { l.setter.SetDERPMap(dm) } } + +type BasicTelemetryController struct { + logger slog.Logger + + sync.Mutex + client TelemetryClient + unavailable bool +} + +func (b *BasicTelemetryController) New(client TelemetryClient) { + b.Lock() + defer b.Unlock() + b.client = client + b.unavailable = false + b.logger.Debug(context.Background(), "new telemetry client connected to controller") +} + +func (b *BasicTelemetryController) SendTelemetryEvent(event *proto.TelemetryEvent) { + b.Lock() + if b.client == nil { + b.Unlock() + b.logger.Debug(context.Background(), + "telemetry event dropped; no client", slog.F("event", event)) + return + } + if b.unavailable { + b.Unlock() + b.logger.Debug(context.Background(), + "telemetry event dropped; unavailable", slog.F("event", event)) + return + } + client := b.client + b.Unlock() + unavailable := sendTelemetry(b.logger, client, event) + if unavailable { + b.Lock() + defer b.Unlock() + if b.client == client { + b.unavailable = true + } + } +} + +func NewBasicTelemetryController(logger slog.Logger) *BasicTelemetryController { + return &BasicTelemetryController{logger: logger} +} + +var ( + _ TelemetrySink = &BasicTelemetryController{} + _ TelemetryController = &BasicTelemetryController{} +) + +func sendTelemetry( + logger slog.Logger, client TelemetryClient, event *proto.TelemetryEvent, +) ( + unavailable bool, +) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := client.PostTelemetry(ctx, &proto.TelemetryRequest{ + Events: []*proto.TelemetryEvent{event}, + }) + if drpcerr.Code(err) == drpcerr.Unimplemented || + drpc.ProtocolError.Has(err) && + strings.Contains(err.Error(), "unknown rpc: ") { + logger.Debug( + context.Background(), + "attempted to send telemetry to a server that doesn't support it", + slog.Error(err), + ) + return true + } else if err != nil { + logger.Warn( + context.Background(), + "failed to post telemetry event", + slog.F("event", event), slog.Error(err), + ) + } + return false +} diff --git a/tailnet/controllers_test.go b/tailnet/controllers_test.go index a6d5b6ec0d..7c810af0c0 100644 --- a/tailnet/controllers_test.go +++ b/tailnet/controllers_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/xerrors" + "storj.io/drpc" + "storj.io/drpc/drpcerr" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -351,3 +353,157 @@ func (f fakeDERPClient) Recv() (*tailcfg.DERPMap, error) { } return nil, io.EOF } + +func TestBasicTelemetryController_Success(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + uut := tailnet.NewBasicTelemetryController(logger) + ft := newFakeTelemetryClient() + uut.New(ft) + + sendDone := make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{ + Id: []byte("test event"), + }) + }() + + call := testutil.RequireRecvCtx(ctx, t, ft.calls) + require.Len(t, call.req.GetEvents(), 1) + require.Equal(t, call.req.GetEvents()[0].GetId(), []byte("test event")) + + testutil.RequireSendCtx(ctx, t, call.errCh, nil) + testutil.RequireRecvCtx(ctx, t, sendDone) +} + +func TestBasicTelemetryController_Unimplemented(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + ft := newFakeTelemetryClient() + + uut := tailnet.NewBasicTelemetryController(logger) + uut.New(ft) + + // bad code, doesn't count + telemetryError := drpcerr.WithCode(xerrors.New("Unimplemented"), 0) + + sendDone := make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + + call := testutil.RequireRecvCtx(ctx, t, ft.calls) + testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError) + testutil.RequireRecvCtx(ctx, t, sendDone) + + sendDone = make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + + // we get another call since it wasn't really the Unimplemented error + call = testutil.RequireRecvCtx(ctx, t, ft.calls) + + // for real this time + telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), drpcerr.Unimplemented) + testutil.RequireSendCtx(ctx, t, call.errCh, telemetryError) + testutil.RequireRecvCtx(ctx, t, sendDone) + + // now this returns immediately without a call, because unimplemented error disables calling + sendDone = make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + testutil.RequireRecvCtx(ctx, t, sendDone) + + // getting a "new" client resets + uut.New(ft) + sendDone = make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + call = testutil.RequireRecvCtx(ctx, t, ft.calls) + testutil.RequireSendCtx(ctx, t, call.errCh, nil) + testutil.RequireRecvCtx(ctx, t, sendDone) +} + +func TestBasicTelemetryController_NotRecognised(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ft := newFakeTelemetryClient() + uut := tailnet.NewBasicTelemetryController(logger) + uut.New(ft) + + sendDone := make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + // returning generic protocol error doesn't trigger unknown rpc logic + call := testutil.RequireRecvCtx(ctx, t, ft.calls) + testutil.RequireSendCtx(ctx, t, call.errCh, drpc.ProtocolError.New("Protocol Error")) + testutil.RequireRecvCtx(ctx, t, sendDone) + + sendDone = make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + call = testutil.RequireRecvCtx(ctx, t, ft.calls) + // return the expected protocol error this time + testutil.RequireSendCtx(ctx, t, call.errCh, + drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry")) + testutil.RequireRecvCtx(ctx, t, sendDone) + + // now this returns immediately without a call, because unimplemented error disables calling + sendDone = make(chan struct{}) + go func() { + defer close(sendDone) + uut.SendTelemetryEvent(&proto.TelemetryEvent{}) + }() + testutil.RequireRecvCtx(ctx, t, sendDone) +} + +type fakeTelemetryClient struct { + calls chan *fakeTelemetryCall +} + +var _ tailnet.TelemetryClient = &fakeTelemetryClient{} + +func newFakeTelemetryClient() *fakeTelemetryClient { + return &fakeTelemetryClient{ + calls: make(chan *fakeTelemetryCall), + } +} + +// PostTelemetry implements tailnet.TelemetryClient +func (f *fakeTelemetryClient) PostTelemetry(ctx context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) { + fr := &fakeTelemetryCall{req: req, errCh: make(chan error)} + select { + case <-ctx.Done(): + return nil, ctx.Err() + case f.calls <- fr: + // OK + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-fr.errCh: + return &proto.TelemetryResponse{}, err + } +} + +type fakeTelemetryCall struct { + req *proto.TelemetryRequest + errCh chan error +}