package tailnet

import (
	"context"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/hashicorp/yamux"
	"golang.org/x/xerrors"
	"storj.io/drpc/drpcmux"
	"storj.io/drpc/drpcserver"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/apiversion"
	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/quartz"
)

type streamIDContextKey struct{}

// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
// on the context, since the information is extracted at the HTTP layer for
// remote clients of the API, or set outside tailnet for local clients (e.g.
// Coderd's single_tailnet).
type StreamID struct {
	Name string
	ID   uuid.UUID
	Auth CoordinateeAuth
}

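// WithStreamID returns a copy of ctx that carries the given StreamID, so that
// the Coordinate RPC handler can identify and authorize the caller.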
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
	return context.WithValue(ctx, streamIDContextKey{}, streamID)
}

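// ClientServiceOptions contains the dependencies used to construct a
// ClientService via NewClientService.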
type ClientServiceOptions struct {
	Logger                  slog.Logger
	CoordPtr                *atomic.Pointer[Coordinator]
	DERPMapUpdateFrequency  time.Duration
	DERPMapFn               func() *tailcfg.DERPMap
	NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
}

// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and supports versions 1.0 and 2.x of the Tailnet API protocol.
type ClientService struct {
	Logger   slog.Logger
	CoordPtr *atomic.Pointer[Coordinator]
	drpc     *drpcserver.Server
}

// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
func NewClientService(options ClientServiceOptions) (
	*ClientService, error,
) {
	s := &ClientService{Logger: options.Logger, CoordPtr: options.CoordPtr}
	mux := drpcmux.New()
	drpcService := &DRPCService{
		CoordPtr:                options.CoordPtr,
		Logger:                  options.Logger,
		DerpMapUpdateFrequency:  options.DERPMapUpdateFrequency,
		DerpMapFn:               options.DERPMapFn,
		NetworkTelemetryHandler: options.NetworkTelemetryHandler,
	}
	err := proto.DRPCRegisterTailnet(mux, drpcService)
	if err != nil {
		return nil, xerrors.Errorf("register DRPC service: %w", err)
	}
	server := drpcserver.NewWithOptions(mux, drpcserver.Options{
		Log: func(err error) {
			if xerrors.Is(err, io.EOF) ||
				xerrors.Is(err, context.Canceled) ||
				xerrors.Is(err, context.DeadlineExceeded) {
				return
			}
			options.Logger.Debug(context.Background(), "drpc server error", slog.Error(err))
		},
	})
	s.drpc = server
	return s, nil
}

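// ServeClient serves a tailnet client over the given connection, dispatching on
// the negotiated API version: 1.x is passed to the legacy Coordinator
// ServeClient method, while 2.x is served over dRPC via ServeConnV2.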
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
	major, _, err := apiversion.Parse(version)
	if err != nil {
		s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
		return err
	}
	switch major {
	case 1:
		coord := *(s.CoordPtr.Load())
		return coord.ServeClient(conn, id, agent)
	case 2:
		auth := ClientCoordinateeAuth{AgentID: agent}
		streamID := StreamID{
			Name: "client",
			ID:   id,
			Auth: auth,
		}
		return s.ServeConnV2(ctx, conn, streamID)
	default:
		s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
		return xerrors.New("unsupported version")
	}
}

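// ServeConnV2 serves the v2 Tailnet API over the given connection by wrapping
// it in a yamux session and handing it to the dRPC server, with the StreamID
// attached to the context.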
func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
	config := yamux.DefaultConfig()
	config.LogOutput = io.Discard
	session, err := yamux.Server(conn, config)
	if err != nil {
		return xerrors.Errorf("yamux init failed: %w", err)
	}
	ctx = WithStreamID(ctx, streamID)
	return s.drpc.Serve(ctx, session)
}

// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer.
type DRPCService struct {
	CoordPtr                *atomic.Pointer[Coordinator]
	Logger                  slog.Logger
	DerpMapUpdateFrequency  time.Duration
	DerpMapFn               func() *tailcfg.DERPMap
	NetworkTelemetryHandler func(batch []*proto.TelemetryEvent)
}

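// PostTelemetry forwards client telemetry events to the configured
// NetworkTelemetryHandler, if any.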
func (s *DRPCService) PostTelemetry(_ context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
	if s.NetworkTelemetryHandler != nil {
		s.NetworkTelemetryHandler(req.Events)
	}
	return &proto.TelemetryResponse{}, nil
}

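// StreamDERPMaps sends the current DERP map to the client, then re-checks it
// every DerpMapUpdateFrequency and sends an update whenever it has changed.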
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCTailnet_StreamDERPMapsStream) error {
	defer stream.Close()

	ticker := time.NewTicker(s.DerpMapUpdateFrequency)
	defer ticker.Stop()

	var lastDERPMap *tailcfg.DERPMap
	for {
		derpMap := s.DerpMapFn()
		if derpMap == nil {
			// in testing, we send nil to close the stream.
			return io.EOF
		}
		if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) {
			protoDERPMap := DERPMapToProto(derpMap)
			err := stream.Send(protoDERPMap)
			if err != nil {
				return xerrors.Errorf("send derp map: %w", err)
			}
			lastDERPMap = derpMap
		}

		ticker.Reset(s.DerpMapUpdateFrequency)
		select {
		case <-stream.Context().Done():
			return nil
		case <-ticker.C:
		}
	}
}

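// Coordinate relays coordination requests and responses between the dRPC
// stream and the Coordinator loaded from CoordPtr, using the StreamID stored
// on the stream's context.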
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
	ctx := stream.Context()
	streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
	if !ok {
		_ = stream.Close()
		return xerrors.New("no Stream ID")
	}
	logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
	logger.Debug(ctx, "starting tailnet Coordinate")
	coord := *(s.CoordPtr.Load())
	reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
	c := communicator{
		logger: logger,
		stream: stream,
		reqs:   reqs,
		resps:  resps,
	}
	c.communicate()
	return nil
}

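// communicator shuttles messages between a Coordinate stream and the channels
// returned by Coordinator.Coordinate.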
type communicator struct {
	logger slog.Logger
	stream proto.DRPCTailnet_CoordinateStream
	reqs   chan<- *proto.CoordinateRequest
	resps  <-chan *proto.CoordinateResponse
}

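// communicate runs the request loop on a new goroutine and the response loop
// on the calling goroutine, returning when the response loop exits.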
func (c communicator) communicate() {
	go c.loopReq()
	c.loopResp()
}

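// loopReq reads requests from the dRPC stream and forwards them to the
// Coordinator, closing the request channel when the stream errors or the
// context is done.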
func (c communicator) loopReq() {
	ctx := c.stream.Context()
	defer close(c.reqs)
	for {
		req, err := c.stream.Recv()
		if err != nil {
			c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err))
			return
		}
		err = SendCtx(ctx, c.reqs, req)
		if err != nil {
			c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err()))
			return
		}
	}
}

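// loopResp reads responses from the Coordinator and writes them to the dRPC
// stream, closing the stream on exit.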
func (c communicator) loopResp() {
	ctx := c.stream.Context()
	defer func() {
		err := c.stream.Close()
		if err != nil {
			c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err))
		}
	}()
	for {
		resp, err := RecvCtx(ctx, c.resps)
		if err != nil {
			c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err))
			return
		}
		err = c.stream.Send(resp)
		if err != nil {
			c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err))
			return
		}
	}
}

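// NetworkTelemetryBatcher collects telemetry events and flushes them to
// batchFn either periodically or once maxSize pending events accumulate.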
type NetworkTelemetryBatcher struct {
	clock     quartz.Clock
	frequency time.Duration
	maxSize   int
	batchFn   func(batch []*proto.TelemetryEvent)

	mu      sync.Mutex
	closed  chan struct{}
	done    chan struct{}
	ticker  *quartz.Ticker
	pending []*proto.TelemetryEvent
}

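// NewNetworkTelemetryBatcher returns a batcher that calls batchFn every
// frequency, or sooner once maxSize events are pending. A nil batchFn is
// replaced with a no-op. A minimal usage sketch (assuming this version of
// github.com/coder/quartz exposes quartz.NewReal, and a caller-supplied
// sendBatch func([]*proto.TelemetryEvent)):
//
//	b := NewNetworkTelemetryBatcher(quartz.NewReal(), 30*time.Second, 100, sendBatch)
//	defer b.Close()
//	opts := ClientServiceOptions{NetworkTelemetryHandler: b.Handler}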
func NewNetworkTelemetryBatcher(clk quartz.Clock, frequency time.Duration, maxSize int, batchFn func(batch []*proto.TelemetryEvent)) *NetworkTelemetryBatcher {
	b := &NetworkTelemetryBatcher{
		clock:     clk,
		frequency: frequency,
		maxSize:   maxSize,
		batchFn:   batchFn,
		closed:    make(chan struct{}),
		done:      make(chan struct{}),
	}
	if b.batchFn == nil {
		b.batchFn = func(batch []*proto.TelemetryEvent) {}
	}
	b.start()
	return b
}

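// Close stops the batcher, flushing any pending events, and returns an error
// if the background goroutine does not exit within five seconds.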
func (b *NetworkTelemetryBatcher) Close() error {
	close(b.closed)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	select {
	case <-ctx.Done():
		return xerrors.New("timed out waiting for batcher to close")
	case <-b.done:
	}
	return nil
}

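// sendTelemetryBatch flushes all pending events to batchFn, if there are any.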
func (b *NetworkTelemetryBatcher) sendTelemetryBatch() {
	b.mu.Lock()
	defer b.mu.Unlock()
	events := b.pending
	if len(events) == 0 {
		return
	}
	b.pending = []*proto.TelemetryEvent{}
	b.batchFn(events)
}

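// start launches the background goroutine that flushes pending events on each
// tick and performs a final flush when the batcher is closed.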
func (b *NetworkTelemetryBatcher) start() {
	b.ticker = b.clock.NewTicker(b.frequency)

	go func() {
		defer func() {
			// The lock prevents Handler from racing with Close.
			b.mu.Lock()
			defer b.mu.Unlock()
			close(b.done)
			b.ticker.Stop()
		}()

		for {
			select {
			case <-b.ticker.C:
				b.sendTelemetryBatch()
				b.ticker.Reset(b.frequency)
			case <-b.closed:
				// Send any remaining telemetry events before exiting.
				b.sendTelemetryBatch()
				return
			}
		}
	}()
}

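// Handler queues events for batching and is safe to pass as a
// ClientServiceOptions.NetworkTelemetryHandler. Once maxSize events are
// pending, it flushes them asynchronously to avoid blocking the dRPC call.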
func (b *NetworkTelemetryBatcher) Handler(events []*proto.TelemetryEvent) {
	b.mu.Lock()
	defer b.mu.Unlock()
	select {
	case <-b.closed:
		return
	default:
	}

	for _, event := range events {
		b.pending = append(b.pending, event)

		if len(b.pending) >= b.maxSize {
			// This can't call sendTelemetryBatch directly because we already
			// hold the lock.
			events := b.pending
			b.pending = []*proto.TelemetryEvent{}
			// Resetting the ticker is best effort. We don't care if the ticker
			// has already fired or has a pending message, because the only risk
			// is that we send two telemetry events in short succession (which
			// is totally fine).
			b.ticker.Reset(b.frequency)
			// Perform the send in a goroutine to avoid blocking the DRPC call.
			go b.batchFn(events)
		}
	}
}