chore: refactor sending telemetry (#15345)

Implements a tailnet API Telemetry controller by refactoring from `workspacesdk`.

chore re: #14729
This commit is contained in:
Spike Curtis
2024-11-06 20:23:23 +04:00
committed by GitHub
parent 9126cd78a6
commit 335e4ab6bf
4 changed files with 254 additions and 186 deletions

View File

@ -8,16 +8,12 @@ import (
"net/http"
"net/url"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"storj.io/drpc"
"storj.io/drpc/drpcerr"
"cdr.dev/slog"
"github.com/coder/coder/v2/buildinfo"
@ -66,19 +62,12 @@ type tailnetAPIConnector struct {
dialOptions *websocket.DialOptions
derpCtrl tailnet.DERPController
coordCtrl tailnet.CoordinationController
customDialFn func() (proto.DRPCTailnetClient, error)
clientMu sync.RWMutex
client proto.DRPCTailnetClient
telCtrl *tailnet.BasicTelemetryController
connected chan error
resumeToken *proto.RefreshResumeTokenResponse
isFirst bool
closed chan struct{}
// Only set to true if we get a response from the server that it doesn't support
// network telemetry.
telemetryUnavailable atomic.Bool
}
// Create a new tailnetAPIConnector without running it
@ -92,6 +81,7 @@ func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uui
dialOptions: dialOptions,
connected: make(chan error, 1),
closed: make(chan struct{}),
telCtrl: tailnet.NewBasicTelemetryController(logger),
}
}
@ -124,9 +114,6 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
if err != nil {
continue
}
tac.clientMu.Lock()
tac.client = tailnetClient
tac.clientMu.Unlock()
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
tac.runConnectorOnce(tailnetClient)
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
@ -141,9 +128,6 @@ var permanentErrorStatuses = []int{
}
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
if tac.customDialFn != nil {
return tac.customDialFn()
}
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
u, err := url.Parse(tac.coordinateURL)
@ -228,6 +212,8 @@ func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient)
}
}()
tac.telCtrl.New(client) // synchronous, doesn't need a goroutine
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
wg := sync.WaitGroup{}
wg.Add(3)
@ -245,10 +231,7 @@ func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient)
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
tac.clientMu.Lock()
client.DRPCConn().Close()
tac.client = nil
tac.clientMu.Unlock()
// Note that derpMap() logs it own errors, we don't bother here.
}
}()
@ -351,20 +334,5 @@ func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.D
}
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
tac.clientMu.RLock()
// We hold the lock for the entire telemetry request, but this would only block
// a coordinate retry, and closing the connection.
defer tac.clientMu.RUnlock()
if tac.client == nil || tac.telemetryUnavailable.Load() {
return
}
ctx, cancel := context.WithTimeout(tac.ctx, 5*time.Second)
defer cancel()
_, err := tac.client.PostTelemetry(ctx, &proto.TelemetryRequest{
Events: []*proto.TelemetryEvent{event},
})
if drpcerr.Code(err) == drpcerr.Unimplemented || drpc.ProtocolError.Has(err) && strings.Contains(err.Error(), "unknown rpc: ") {
tac.logger.Debug(tac.ctx, "attempted to send telemetry to a server that doesn't support it", slog.Error(err))
tac.telemetryUnavailable.Store(true)
}
tac.telCtrl.SendTelemetryEvent(event)
}

View File

@ -13,12 +13,8 @@ import (
"github.com/hashicorp/yamux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
"storj.io/drpc"
"storj.io/drpc/drpcerr"
"tailscale.com/tailcfg"
"cdr.dev/slog"
@ -385,7 +381,12 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
DERPMapUpdateFrequency: time.Millisecond,
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
testutil.RequireSendCtx(ctx, t, eventCh, batch)
select {
case <-ctx.Done():
t.Error("timeout sending telemetry event")
case eventCh <- batch:
t.Log("sent telemetry batch")
}
},
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
@ -409,11 +410,10 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
uut.runConnector(fConn)
require.Eventually(t, func() bool {
uut.clientMu.Lock()
defer uut.clientMu.Unlock()
return uut.client != nil
}, testutil.WaitShort, testutil.IntervalFast)
// Coordinate calls happen _after_ telemetry is connected up, so we use this
// to ensure telemetry is connected before sending our event
cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
defer close(cc.Resps)
uut.SendTelemetryEvent(&proto.TelemetryEvent{
Id: []byte("test event"),
@ -425,86 +425,6 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
require.Equal(t, []byte("test event"), testEvents[0].Id)
}
func TestTailnetAPIConnector_TelemetryUnimplemented(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
fConn := newFakeTailnetConn()
fakeDRPCClient := newFakeDRPCClient()
uut := &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
connected: make(chan error, 1),
closed: make(chan struct{}),
customDialFn: func() (proto.DRPCTailnetClient, error) {
return fakeDRPCClient, nil
},
}
uut.runConnector(fConn)
require.Eventually(t, func() bool {
uut.clientMu.Lock()
defer uut.clientMu.Unlock()
return uut.client != nil
}, testutil.WaitShort, testutil.IntervalFast)
fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), 0)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.False(t, uut.telemetryUnavailable.Load())
require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
fakeDRPCClient.telemetryError = drpcerr.WithCode(xerrors.New("Unimplemented"), drpcerr.Unimplemented)
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.True(t, uut.telemetryUnavailable.Load())
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
}
func TestTailnetAPIConnector_TelemetryNotRecognised(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
agentID := uuid.UUID{0x55}
fConn := newFakeTailnetConn()
fakeDRPCClient := newFakeDRPCClient()
uut := &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: "",
clock: quartz.NewReal(),
dialOptions: &websocket.DialOptions{},
connected: make(chan error, 1),
closed: make(chan struct{}),
customDialFn: func() (proto.DRPCTailnetClient, error) {
return fakeDRPCClient, nil
},
}
uut.runConnector(fConn)
require.Eventually(t, func() bool {
uut.clientMu.Lock()
defer uut.clientMu.Unlock()
return uut.client != nil
}, testutil.WaitShort, testutil.IntervalFast)
fakeDRPCClient.telemetryError = drpc.ProtocolError.New("Protocol Error")
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.False(t, uut.telemetryUnavailable.Load())
require.Equal(t, int64(1), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
fakeDRPCClient.telemetryError = drpc.ProtocolError.New("unknown rpc: /coder.tailnet.v2.Tailnet/PostTelemetry")
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.True(t, uut.telemetryUnavailable.Load())
uut.SendTelemetryEvent(&proto.TelemetryEvent{})
require.Equal(t, int64(2), atomic.LoadInt64(&fakeDRPCClient.postTelemetryCalls))
}
type fakeTailnetConn struct{}
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
@ -524,65 +444,6 @@ func newFakeTailnetConn() *fakeTailnetConn {
return &fakeTailnetConn{}
}
type fakeDRPCClient struct {
postTelemetryCalls int64
refreshTokenFn func(context.Context, *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
telemetryError error
fakeDRPPCMapStream
}
var _ proto.DRPCTailnetClient = &fakeDRPCClient{}
func newFakeDRPCClient() *fakeDRPCClient {
return &fakeDRPCClient{
postTelemetryCalls: 0,
fakeDRPPCMapStream: fakeDRPPCMapStream{
fakeDRPCStream: fakeDRPCStream{
ch: make(chan struct{}),
},
},
}
}
// Coordinate implements proto.DRPCTailnetClient.
func (f *fakeDRPCClient) Coordinate(_ context.Context) (proto.DRPCTailnet_CoordinateClient, error) {
return &f.fakeDRPCStream, nil
}
// DRPCConn implements proto.DRPCTailnetClient.
func (*fakeDRPCClient) DRPCConn() drpc.Conn {
return &fakeDRPCConn{}
}
// PostTelemetry implements proto.DRPCTailnetClient.
func (f *fakeDRPCClient) PostTelemetry(_ context.Context, _ *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
atomic.AddInt64(&f.postTelemetryCalls, 1)
return nil, f.telemetryError
}
// StreamDERPMaps implements proto.DRPCTailnetClient.
func (f *fakeDRPCClient) StreamDERPMaps(_ context.Context, _ *proto.StreamDERPMapsRequest) (proto.DRPCTailnet_StreamDERPMapsClient, error) {
return &f.fakeDRPPCMapStream, nil
}
// RefreshResumeToken implements proto.DRPCTailnetClient.
func (f *fakeDRPCClient) RefreshResumeToken(_ context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
if f.refreshTokenFn != nil {
return f.refreshTokenFn(context.Background(), nil)
}
return &proto.RefreshResumeTokenResponse{
Token: "test",
RefreshIn: durationpb.New(30 * time.Minute),
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}, nil
}
// WorkspaceUpdates implements proto.DRPCTailnetClient.
func (*fakeDRPCClient) WorkspaceUpdates(context.Context, *proto.WorkspaceUpdatesRequest) (proto.DRPCTailnet_WorkspaceUpdatesClient, error) {
panic("unimplemented")
}
type fakeDRPCConn struct{}
var _ drpc.Conn = &fakeDRPCConn{}