Mirror of https://github.com/coder/coder.git, synced 2025-07-09 11:45:56 +00:00
Closes #14716, closes #14717.

Adds a new user-scoped tailnet API endpoint (`api/v2/tailnet`) with a new RPC stream for receiving updates on workspaces owned by a specific user, as defined in #14716. When a stream is started, the `WorkspaceUpdatesProvider` begins listening on the user-scoped pubsub events implemented in #14964. When a relevant event type is seen (such as a workspace state transition), the provider queries the DB for all the workspaces (and agents) owned by the user. The result is compared against the result of the previous query to produce a set of workspace updates.

Workspace updates can be requested for any user ID, but only workspaces the authorised user is permitted to `ActionRead` will have their updates streamed. Opening a tunnel to an agent requires that the user can perform `ActionSSH` against the workspace containing it.
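To make the wiring concrete, here is a minimal, hypothetical sketch of how a caller (e.g. coderd) might construct and serve this API over an accepted connection. The `Coordinator`, `ResumeTokenProvider`, and `WorkspaceUpdatesProvider` values, the 5-second DERP map frequency, and the `serveTailnetConn` helper itself are assumptions for illustration; only the exported names defined in the file below come from this change.

package example // illustrative only; not part of this change

import (
    "context"
    "net"
    "sync/atomic"
    "time"

    "cdr.dev/slog"
    "tailscale.com/tailcfg"

    "github.com/coder/coder/v2/tailnet"
    "github.com/coder/coder/v2/tailnet/proto"
)

// serveTailnetConn wires a ClientService around dependencies the caller is
// assumed to already have (coordinator pointer, resume-token provider,
// workspace-updates provider, DERP map source) and serves one connection.
func serveTailnetConn(
    ctx context.Context,
    logger slog.Logger,
    coordPtr *atomic.Pointer[tailnet.Coordinator],
    resumeTokens tailnet.ResumeTokenProvider,
    updatesProvider tailnet.WorkspaceUpdatesProvider,
    derpMapFn func() *tailcfg.DERPMap,
    conn net.Conn,
    streamID tailnet.StreamID,
) error {
    svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
        Logger:                   logger,
        CoordPtr:                 coordPtr,
        DERPMapUpdateFrequency:   5 * time.Second, // arbitrary example value
        DERPMapFn:                derpMapFn,
        NetworkTelemetryHandler:  func(batch []*proto.TelemetryEvent) {}, // drop telemetry in this sketch
        ResumeTokenProvider:      resumeTokens,
        WorkspaceUpdatesProvider: updatesProvider,
    })
    if err != nil {
        return err
    }
    // The version string normally comes from the client's request; any major
    // version 2 value routes to the dRPC ServeConnV2 path.
    return svc.ServeClient(ctx, "2.0", conn, streamID)
}

The sketch leaves constructing the `CoordinateeAuth` for `StreamID.Auth` to the caller, since that is where the read/SSH authorization described above is enforced.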
407 lines · 11 KiB · Go
package tailnet

import (
    "context"
    "io"
    "net"
    "sync"
    "sync/atomic"
    "time"

    "github.com/google/uuid"
    "github.com/hashicorp/yamux"
    "golang.org/x/xerrors"
    "storj.io/drpc/drpcmux"
    "storj.io/drpc/drpcserver"
    "tailscale.com/tailcfg"

    "cdr.dev/slog"
    "github.com/coder/coder/v2/apiversion"
    "github.com/coder/coder/v2/tailnet/proto"
    "github.com/coder/quartz"
)

var ErrUnsupportedVersion = xerrors.New("unsupported version")

type streamIDContextKey struct{}

// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
// on the context, since the information is extracted at the HTTP layer for
// remote clients of the API, or set outside tailnet for local clients (e.g.
// Coderd's single_tailnet)
type StreamID struct {
    Name string
    ID   uuid.UUID
    Auth CoordinateeAuth
}

func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
    return context.WithValue(ctx, streamIDContextKey{}, streamID)
}

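// WorkspaceUpdatesProvider produces per-user streams of workspace updates. In
// coderd, the implementation listens on user-scoped pubsub events and, on a
// relevant event, diffs the owner's workspaces and agents against the previous
// query to produce each update.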
type WorkspaceUpdatesProvider interface {
    io.Closer
    Subscribe(ctx context.Context, userID uuid.UUID) (Subscription, error)
}

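// Subscription is one subscriber's handle on a user's workspace updates:
// Updates returns the channel on which updates are delivered, and Close stops
// the subscription.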
type Subscription interface {
    io.Closer
    Updates() <-chan *proto.WorkspaceUpdate
}

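// TunnelAuthorizer reports whether the caller may open a tunnel to the given
// agent; in coderd this corresponds to ActionSSH on the workspace that
// contains the agent.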
type TunnelAuthorizer interface {
    AuthorizeTunnel(ctx context.Context, agentID uuid.UUID) error
}

type ClientServiceOptions struct {
    Logger                   slog.Logger
    CoordPtr                 *atomic.Pointer[Coordinator]
    DERPMapUpdateFrequency   time.Duration
    DERPMapFn                func() *tailcfg.DERPMap
    NetworkTelemetryHandler  func(batch []*proto.TelemetryEvent)
    ResumeTokenProvider      ResumeTokenProvider
    WorkspaceUpdatesProvider WorkspaceUpdatesProvider
}

// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and supports versions 2.x of the Tailnet API protocol.
type ClientService struct {
    Logger   slog.Logger
    CoordPtr *atomic.Pointer[Coordinator]
    drpc     *drpcserver.Server
}

// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
func NewClientService(options ClientServiceOptions) (
    *ClientService, error,
) {
    s := &ClientService{Logger: options.Logger, CoordPtr: options.CoordPtr}
    mux := drpcmux.New()
    drpcService := &DRPCService{
        CoordPtr:                 options.CoordPtr,
        Logger:                   options.Logger,
        DerpMapUpdateFrequency:   options.DERPMapUpdateFrequency,
        DerpMapFn:                options.DERPMapFn,
        NetworkTelemetryHandler:  options.NetworkTelemetryHandler,
        ResumeTokenProvider:      options.ResumeTokenProvider,
        WorkspaceUpdatesProvider: options.WorkspaceUpdatesProvider,
    }
    err := proto.DRPCRegisterTailnet(mux, drpcService)
    if err != nil {
        return nil, xerrors.Errorf("register DRPC service: %w", err)
    }
    server := drpcserver.NewWithOptions(mux, drpcserver.Options{
        Log: func(err error) {
            if xerrors.Is(err, io.EOF) ||
                xerrors.Is(err, context.Canceled) ||
                xerrors.Is(err, context.DeadlineExceeded) {
                return
            }
            options.Logger.Debug(context.Background(), "drpc server error", slog.Error(err))
        },
    })
    s.drpc = server
    return s, nil
}

func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, streamID StreamID) error {
    major, _, err := apiversion.Parse(version)
    if err != nil {
        s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
        return err
    }
    switch major {
    case 2:
        return s.ServeConnV2(ctx, conn, streamID)
    default:
        s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
        return ErrUnsupportedVersion
    }
}

func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
    config := yamux.DefaultConfig()
    config.LogOutput = io.Discard
    session, err := yamux.Server(conn, config)
    if err != nil {
        return xerrors.Errorf("yamux init failed: %w", err)
    }
    ctx = WithStreamID(ctx, streamID)
    s.Logger.Debug(ctx, "serving dRPC tailnet v2 API session",
        slog.F("peer_id", streamID.ID.String()))
    return s.drpc.Serve(ctx, session)
}

// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
type DRPCService struct {
    CoordPtr                 *atomic.Pointer[Coordinator]
    Logger                   slog.Logger
    DerpMapUpdateFrequency   time.Duration
    DerpMapFn                func() *tailcfg.DERPMap
    NetworkTelemetryHandler  func(batch []*proto.TelemetryEvent)
    ResumeTokenProvider      ResumeTokenProvider
    WorkspaceUpdatesProvider WorkspaceUpdatesProvider
}

func (s *DRPCService) PostTelemetry(_ context.Context, req *proto.TelemetryRequest) (*proto.TelemetryResponse, error) {
    if s.NetworkTelemetryHandler != nil {
        s.NetworkTelemetryHandler(req.Events)
    }
    return &proto.TelemetryResponse{}, nil
}

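// StreamDERPMaps pushes the DERP map to the client whenever it changes,
// polling DerpMapFn every DerpMapUpdateFrequency and only sending when the map
// differs from the last one sent. The first map is sent immediately.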
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCTailnet_StreamDERPMapsStream) error {
    defer stream.Close()

    ticker := time.NewTicker(s.DerpMapUpdateFrequency)
    defer ticker.Stop()

    var lastDERPMap *tailcfg.DERPMap
    for {
        derpMap := s.DerpMapFn()
        if derpMap == nil {
            // in testing, we send nil to close the stream.
            return io.EOF
        }
        if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) {
            protoDERPMap := DERPMapToProto(derpMap)
            err := stream.Send(protoDERPMap)
            if err != nil {
                return xerrors.Errorf("send derp map: %w", err)
            }
            lastDERPMap = derpMap
        }

        ticker.Reset(s.DerpMapUpdateFrequency)
        select {
        case <-stream.Context().Done():
            return nil
        case <-ticker.C:
        }
    }
}

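// RefreshResumeToken issues a fresh resume token for the peer identified by
// the StreamID stored on the context.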
func (s *DRPCService) RefreshResumeToken(ctx context.Context, _ *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error) {
    streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
    if !ok {
        return nil, xerrors.New("no Stream ID")
    }

    res, err := s.ResumeTokenProvider.GenerateResumeToken(ctx, streamID.ID)
    if err != nil {
        return nil, xerrors.Errorf("generate resume token: %w", err)
    }
    return res, nil
}

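// Coordinate connects the stream to the active Coordinator using the identity
// and auth from the StreamID on the context, then shuttles coordination
// requests and responses in both directions until either side closes.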
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
    ctx := stream.Context()
    streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
    if !ok {
        _ = stream.Close()
        return xerrors.New("no Stream ID")
    }
    logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
    logger.Debug(ctx, "starting tailnet Coordinate")
    coord := *(s.CoordPtr.Load())
    reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
    c := communicator{
        logger: logger,
        stream: stream,
        reqs:   reqs,
        resps:  resps,
    }
    c.communicate()
    return nil
}

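// WorkspaceUpdates subscribes to updates for the workspaces owned by the
// requested user and forwards each update to the client until the subscription
// or the stream's context ends.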
func (s *DRPCService) WorkspaceUpdates(req *proto.WorkspaceUpdatesRequest, stream proto.DRPCTailnet_WorkspaceUpdatesStream) error {
    defer stream.Close()

    ctx := stream.Context()

    ownerID, err := uuid.FromBytes(req.WorkspaceOwnerId)
    if err != nil {
        return xerrors.Errorf("parse workspace owner ID: %w", err)
    }

    sub, err := s.WorkspaceUpdatesProvider.Subscribe(ctx, ownerID)
    if err != nil {
        return xerrors.Errorf("subscribe to workspace updates: %w", err)
    }
    defer sub.Close()

    for {
        select {
        case updates, ok := <-sub.Updates():
            if !ok {
                return nil
            }
            err := stream.Send(updates)
            if err != nil {
                return xerrors.Errorf("send workspace update: %w", err)
            }
        case <-stream.Context().Done():
            return nil
        }
    }
}

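// communicator bridges a Coordinate stream and the Coordinator's channels:
// loopReq copies requests received from the client into reqs, while loopResp
// copies responses from the Coordinator back onto the stream.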
type communicator struct {
    logger slog.Logger
    stream proto.DRPCTailnet_CoordinateStream
    reqs   chan<- *proto.CoordinateRequest
    resps  <-chan *proto.CoordinateResponse
}

func (c communicator) communicate() {
    go c.loopReq()
    c.loopResp()
}

func (c communicator) loopReq() {
    ctx := c.stream.Context()
    defer close(c.reqs)
    for {
        req, err := c.stream.Recv()
        if err != nil {
            c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err))
            return
        }
        err = SendCtx(ctx, c.reqs, req)
        if err != nil {
            c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err()))
            return
        }
    }
}

func (c communicator) loopResp() {
    ctx := c.stream.Context()
    defer func() {
        err := c.stream.Close()
        if err != nil {
            c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err))
        }
    }()
    for {
        resp, err := RecvCtx(ctx, c.resps)
        if err != nil {
            c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err))
            return
        }
        err = c.stream.Send(resp)
        if err != nil {
            c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err))
            return
        }
    }
}

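// NetworkTelemetryBatcher buffers telemetry events and flushes them to batchFn
// either every frequency or as soon as maxSize events are pending, whichever
// comes first.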
type NetworkTelemetryBatcher struct {
    clock     quartz.Clock
    frequency time.Duration
    maxSize   int
    batchFn   func(batch []*proto.TelemetryEvent)

    mu      sync.Mutex
    closed  chan struct{}
    done    chan struct{}
    ticker  *quartz.Ticker
    pending []*proto.TelemetryEvent
}

func NewNetworkTelemetryBatcher(clk quartz.Clock, frequency time.Duration, maxSize int, batchFn func(batch []*proto.TelemetryEvent)) *NetworkTelemetryBatcher {
    b := &NetworkTelemetryBatcher{
        clock:     clk,
        frequency: frequency,
        maxSize:   maxSize,
        batchFn:   batchFn,
        closed:    make(chan struct{}),
        done:      make(chan struct{}),
    }
    if b.batchFn == nil {
        b.batchFn = func(batch []*proto.TelemetryEvent) {}
    }
    b.start()
    return b
}

func (b *NetworkTelemetryBatcher) Close() error {
    close(b.closed)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    select {
    case <-ctx.Done():
        return xerrors.New("timed out waiting for batcher to close")
    case <-b.done:
    }
    return nil
}

func (b *NetworkTelemetryBatcher) sendTelemetryBatch() {
    b.mu.Lock()
    defer b.mu.Unlock()
    events := b.pending
    if len(events) == 0 {
        return
    }
    b.pending = []*proto.TelemetryEvent{}
    b.batchFn(events)
}

func (b *NetworkTelemetryBatcher) start() {
    b.ticker = b.clock.NewTicker(b.frequency)

    go func() {
        defer func() {
            // The lock prevents Handler from racing with Close.
            b.mu.Lock()
            defer b.mu.Unlock()
            close(b.done)
            b.ticker.Stop()
        }()

        for {
            select {
            case <-b.ticker.C:
                b.sendTelemetryBatch()
                b.ticker.Reset(b.frequency)
            case <-b.closed:
                // Send any remaining telemetry events before exiting.
                b.sendTelemetryBatch()
                return
            }
        }
    }()
}

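// Handler queues incoming telemetry events, flushing asynchronously once
// maxSize is reached; events arriving after Close are dropped.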
func (b *NetworkTelemetryBatcher) Handler(events []*proto.TelemetryEvent) {
    b.mu.Lock()
    defer b.mu.Unlock()
    select {
    case <-b.closed:
        return
    default:
    }

    for _, event := range events {
        b.pending = append(b.pending, event)

        if len(b.pending) >= b.maxSize {
            // This can't call sendTelemetryBatch directly because we already
            // hold the lock.
            events := b.pending
            b.pending = []*proto.TelemetryEvent{}
            // Resetting the ticker is best effort. We don't care if the ticker
            // has already fired or has a pending message, because the only risk
            // is that we send two telemetry events in short succession (which
            // is totally fine).
            b.ticker.Reset(b.frequency)
            // Perform the send in a goroutine to avoid blocking the DRPC call.
            go b.batchFn(events)
        }
    }
}