Files
coder/tailnet/controllers.go
Spike Curtis d7e86278c8 chore: add resume token controller (#15346)
Implements a controller for the Tailnet API resume token RPC, by refactoring from `workspacesdk`.

chore re: #14729
2024-11-07 11:32:20 +04:00

667 lines
18 KiB
Go

package tailnet
import (
"context"
"fmt"
"io"
"math"
"strings"
"sync"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcerr"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
// A Controller connects to the tailnet control plane, and then uses the control protocols to
// program a tailnet.Conn in production (in test it could be an interface simulating the Conn). It
// delegates this task to sub-controllers responsible for the main areas of the tailnet control
// protocol: coordination, DERP map updates, resume tokens, and telemetry.
type Controller struct {
// Dialer connects to the control plane and returns the protocol clients.
Dialer ControlProtocolDialer
// CoordCtrl handles the coordination protocol.
CoordCtrl CoordinationController
// DERPCtrl handles DERPMap updates.
DERPCtrl DERPController
// ResumeTokenCtrl handles refreshing and exposing the resume token.
ResumeTokenCtrl ResumeTokenController
// TelemetryCtrl receives the telemetry client.
TelemetryCtrl TelemetryController
}
// CloserWaiter is the handle returned by the sub-controllers' New methods for a running protocol
// handler.
type CloserWaiter interface {
// Close shuts the handler down. The implementations in this file that wait for their receive
// loop to exit stop waiting once the context is done.
Close(context.Context) error
// Wait returns a channel on which the handler reports an error.
Wait() <-chan error
}
// CoordinatorClient is an abstraction of the Coordinator's control protocol interface from the
// perspective of a protocol client (i.e. the Coder Agent is also a client of this interface).
type CoordinatorClient interface {
// Close closes the client.
Close() error
// Send sends one request to the Coordinator.
Send(*proto.CoordinateRequest) error
// Recv returns the next response from the Coordinator, or an error.
Recv() (*proto.CoordinateResponse, error)
}
// A CoordinationController accepts connections to the control plane, and handles the Coordination
// protocol on behalf of some Coordinatee (tailnet.Conn in production). This is the "glue" code
// between them.
type CoordinationController interface {
// New begins handling the coordination protocol over the given client and returns a handle to
// close it and wait for errors.
New(CoordinatorClient) CloserWaiter
}
// DERPClient is an abstraction of the stream of DERPMap updates from the control plane.
type DERPClient interface {
// Close closes the stream.
Close() error
// Recv returns the next DERPMap from the stream, or an error.
Recv() (*tailcfg.DERPMap, error)
}
// A DERPController accepts connections to the control plane, and handles the DERPMap updates
// delivered over them by programming the data plane (tailnet.Conn or some test interface).
type DERPController interface {
// New begins handling DERPMap updates from the given client and returns a handle to close it
// and wait for errors.
New(DERPClient) CloserWaiter
}
// ResumeTokenClient is an abstraction of the control plane RPC used to refresh a resume token.
type ResumeTokenClient interface {
// RefreshResumeToken requests a fresh resume token from the control plane.
RefreshResumeToken(ctx context.Context, in *proto.RefreshResumeTokenRequest) (*proto.RefreshResumeTokenResponse, error)
}
// A ResumeTokenController accepts connections to the control plane, keeps a resume token
// refreshed over them, and makes the current token available to callers.
type ResumeTokenController interface {
// New begins refreshing the token using the given client and returns a handle to close the
// refresher and wait for errors.
New(ResumeTokenClient) CloserWaiter
// Token returns the current resume token; the bool is false when no usable token is held.
Token() (string, bool)
}
// TelemetryClient is an abstraction of the control plane RPC used to post telemetry events.
type TelemetryClient interface {
// PostTelemetry sends a batch of telemetry events to the control plane.
PostTelemetry(ctx context.Context, in *proto.TelemetryRequest) (*proto.TelemetryResponse, error)
}
// A TelemetryController accepts connections to the control plane for posting telemetry. Unlike
// the other sub-controllers, its New method returns no handle.
type TelemetryController interface {
// New hands a newly connected telemetry client to the controller.
New(TelemetryClient)
}
// ControlProtocolClients represents an abstract interface to the tailnet control plane via a set
// of protocol clients. The Closer should close all the clients (e.g. by closing the underlying
// connection).
type ControlProtocolClients struct {
// Closer closes all of the clients below.
Closer io.Closer
// Coordinator is the client for the coordination protocol.
Coordinator CoordinatorClient
// DERP is the client for the stream of DERPMap updates.
DERP DERPClient
// ResumeToken is the client for refreshing resume tokens.
ResumeToken ResumeTokenClient
// Telemetry is the client for posting telemetry events.
Telemetry TelemetryClient
}
// ControlProtocolDialer establishes connections to the tailnet control plane.
type ControlProtocolDialer interface {
// Dial connects to the tailnet control plane and returns clients for the different control
// sub-protocols (coordination, DERP maps, resume tokens, and telemetry). If the
// ResumeTokenController is not nil, the dialer should query for a resume token and use it to
// dial, if available.
Dial(ctx context.Context, r ResumeTokenController) (ControlProtocolClients, error)
}
// basicCoordinationController handles the basic coordination operations common to all types of
// tailnet consumers:
//
// 1. sending local node updates to the Coordinator
// 2. receiving peer node updates and programming them into the Coordinatee (e.g. tailnet.Conn)
// 3. (optionally) sending ReadyForHandshake acknowledgements for peer updates.
type basicCoordinationController struct {
logger slog.Logger
// coordinatee is the data plane that receives peer updates and supplies local node updates.
coordinatee Coordinatee
// sendAcks enables sending ReadyForHandshake acknowledgements for received NODE peer updates.
sendAcks bool
}
// New starts a coordination over client. It registers a node callback on the Coordinatee that
// forwards each local node update to the Coordinator as an UpdateSelf request, starts the
// response loop in its own goroutine, and returns the coordination as a CloserWaiter.
func (c *basicCoordinationController) New(client CoordinatorClient) CloserWaiter {
	coord := &basicCoordination{
		logger:       c.logger,
		coordinatee:  c.coordinatee,
		client:       client,
		sendAcks:     c.sendAcks,
		errChan:      make(chan error, 1),
		respLoopDone: make(chan struct{}),
	}

	// onNode is invoked by the Coordinatee whenever the local node changes.
	onNode := func(node *Node) {
		pn, err := NodeToProto(node)
		if err != nil {
			coord.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
			coord.sendErr(err)
			return
		}

		// Hold the lock so that the closed check and the send cannot interleave with Close.
		coord.Lock()
		defer coord.Unlock()
		if coord.closed {
			coord.logger.Debug(context.Background(), "ignored node update because coordination is closed")
			return
		}
		req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}
		if err = coord.client.Send(req); err != nil {
			coord.sendErr(xerrors.Errorf("write: %w", err))
		}
	}
	c.coordinatee.SetNodeCallback(onNode)

	go coord.respLoop()
	return coord
}
// basicCoordination is a single running instance of the coordination protocol over one
// CoordinatorClient, created by basicCoordinationController.New.
type basicCoordination struct {
// Mutex is held by the node callback and by Close while they check closed and send.
sync.Mutex
// closed is set by Close; node updates are ignored once it is true.
closed bool
// errChan carries the error returned from Wait; it is created with a buffer of 1.
errChan chan error
coordinatee Coordinatee
logger slog.Logger
client CoordinatorClient
// respLoopDone is closed when respLoop exits.
respLoopDone chan struct{}
// sendAcks enables ReadyForHandshake acknowledgements in respLoop.
sendAcks bool
}
// Close sends a Disconnect request to the Coordinator and then waits for the response loop to
// exit. If ctx is done before that happens, it forcefully closes the client and then waits for
// the response loop to finish. Calls after the first return nil immediately.
//
// The returned error is the failure to send Disconnect (io.EOF is not treated as a failure), or,
// if sending succeeded and the client had to be closed forcefully, the error from closing it.
// The mutex is held for the whole call, including the wait.
func (c *basicCoordination) Close(ctx context.Context) (retErr error) {
c.Lock()
defer c.Unlock()
if c.closed {
return nil
}
c.closed = true
// This deferred function runs after the Disconnect below has been attempted, and before the
// deferred Unlock above.
defer func() {
// We shouldn't just close the protocol right away, because the way dRPC streams work is
// that if you close them, that could take effect immediately, even before the Disconnect
// message is processed. Coordinators are supposed to hang up on us once they get a
// Disconnect message, so we should wait around for that until the context expires.
select {
case <-c.respLoopDone:
c.logger.Debug(ctx, "responses closed after disconnect")
return
case <-ctx.Done():
c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
}
// forcefully close the stream
protoErr := c.client.Close()
<-c.respLoopDone
// Don't mask an earlier error from sending Disconnect.
if retErr == nil {
retErr = protoErr
}
}()
err := c.client.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
if err != nil && !xerrors.Is(err, io.EOF) {
// Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
return xerrors.Errorf("send disconnect: %w", err)
}
c.logger.Debug(context.Background(), "sent disconnect")
return nil
}
// Wait returns the channel on which errors from the node callback and the response loop are
// reported.
func (c *basicCoordination) Wait() <-chan error {
return c.errChan
}
// sendErr reports err on errChan without blocking. If the send cannot proceed immediately (for
// example because an earlier error is still buffered), err is dropped.
func (c *basicCoordination) sendErr(err error) {
select {
case c.errChan <- err:
default:
}
}
// respLoop reads responses from the Coordinator until an error occurs, programming each batch of
// peer updates into the Coordinatee and, when sendAcks is set, acknowledging NODE updates with
// ReadyForHandshake. On exit it closes the client, marks all peers lost on the Coordinatee, and
// closes respLoopDone.
func (c *basicCoordination) respLoop() {
	defer func() {
		if cErr := c.client.Close(); cErr != nil {
			c.logger.Debug(context.Background(), "failed to close coordinate client after respLoop exit", slog.Error(cErr))
		}
		c.coordinatee.SetAllPeersLost()
		close(c.respLoopDone)
	}()

	for {
		resp, err := c.client.Recv()
		if err != nil {
			c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
			c.sendErr(xerrors.Errorf("read: %w", err))
			return
		}

		updates := resp.GetPeerUpdates()
		if err = c.coordinatee.UpdatePeers(updates); err != nil {
			c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
			c.sendErr(xerrors.Errorf("update peers: %w", err))
			return
		}

		// Only send ReadyForHandshake acks from peers without a target.
		if !c.sendAcks {
			continue
		}

		// Acknowledge every received NODE update. This could potentially be smarter and only
		// send an ACK once per client, but there's nothing currently stopping clients from
		// reusing IDs.
		var acks []*proto.CoordinateRequest_ReadyForHandshake
		for _, update := range updates {
			if update.Kind == proto.CoordinateResponse_PeerUpdate_NODE {
				acks = append(acks, &proto.CoordinateRequest_ReadyForHandshake{Id: update.Id})
			}
		}
		if len(acks) == 0 {
			continue
		}
		if err = c.client.Send(&proto.CoordinateRequest{ReadyForHandshake: acks}); err != nil {
			c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
			c.sendErr(xerrors.Errorf("send: %w", err))
			return
		}
	}
}
// singleDestController is a basicCoordinationController that additionally requests a tunnel to
// one destination on every new coordination.
type singleDestController struct {
*basicCoordinationController
// dest is the ID of the tunnel destination.
dest uuid.UUID
}
// NewSingleDestController creates a CoordinationController for Coder clients that connect to a
// single tunnel destination, e.g. `coder ssh`, which connects to a single workspace Agent. It
// sets dest as the tunnel destination on the Coordinatee; the returned controller does not send
// ReadyForHandshake acknowledgements.
func NewSingleDestController(logger slog.Logger, coordinatee Coordinatee, dest uuid.UUID) CoordinationController {
	coordinatee.SetTunnelDestination(dest)

	basic := &basicCoordinationController{
		logger:      logger,
		coordinatee: coordinatee,
		sendAcks:    false,
	}
	return &singleDestController{basicCoordinationController: basic, dest: dest}
}
// New starts a basic coordination over client and then asks the Coordinator to add a tunnel to
// the controller's destination. A failure to send the AddTunnel request is reported on the
// coordination's error channel rather than returned.
func (c *singleDestController) New(client CoordinatorClient) CloserWaiter {
	// nolint: forcetypeassert
	coord := c.basicCoordinationController.New(client).(*basicCoordination)

	addTunnel := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: c.dest[:]}}
	if err := client.Send(addTunnel); err != nil {
		coord.sendErr(err)
	}
	return coord
}
// NewAgentCoordinationController creates a CoordinationController for Coder Agents, which never
// create tunnels and always send ReadyForHandshake acknowledgements.
func NewAgentCoordinationController(logger slog.Logger, coordinatee Coordinatee) CoordinationController {
	ctrl := new(basicCoordinationController)
	ctrl.logger = logger
	ctrl.coordinatee = coordinatee
	ctrl.sendAcks = true
	return ctrl
}
// inMemoryCoordClient is a CoordinatorClient that exchanges requests and responses with a
// Coordinator over channels.
type inMemoryCoordClient struct {
// Mutex guards closed and is held while sending on reqs.
sync.Mutex
// ctx is canceled by Close (via cancel) to unblock Send and Recv.
ctx context.Context
cancel context.CancelFunc
// closed records that reqs has been closed, so that it is closed only once.
closed bool
logger slog.Logger
// resps delivers responses from the Coordinator.
resps <-chan *proto.CoordinateResponse
// reqs carries requests to the Coordinator; Close closes it.
reqs chan<- *proto.CoordinateRequest
}
// Close cancels the client's context and closes the request channel. It is safe to call more
// than once; the channel is closed only on the first call. It always returns nil.
func (c *inMemoryCoordClient) Close() error {
	// Cancel before taking the lock so that a Send blocked while holding it is released.
	c.cancel()

	c.Lock()
	defer c.Unlock()
	if !c.closed {
		c.closed = true
		close(c.reqs)
	}
	return nil
}
// Send delivers request to the Coordinator over the request channel. It returns a
// drpc.ClosedError if the client is already closed or its context is canceled before the request
// is accepted.
func (c *inMemoryCoordClient) Send(request *proto.CoordinateRequest) error {
	c.Lock()
	defer c.Unlock()

	if !c.closed {
		select {
		case c.reqs <- request:
			return nil
		case <-c.ctx.Done():
			// closed while waiting; report it below
		}
	}
	return drpc.ClosedError.New("in-memory coordinator client closed")
}
// Recv returns the next response from the Coordinator. It returns io.EOF once the response
// channel has been closed, and a drpc.ClosedError if the client's context is canceled first.
func (c *inMemoryCoordClient) Recv() (*proto.CoordinateResponse, error) {
	select {
	case <-c.ctx.Done():
		return nil, drpc.ClosedError.New("in-memory coord client closed")
	case resp, ok := <-c.resps:
		if !ok {
			// response from Coordinator was closed, so close the send direction as well, so
			// that the Coordinator won't be waiting for us while shutting down.
			_ = c.Close()
			return nil, io.EOF
		}
		return resp, nil
	}
}
// NewInMemoryCoordinatorClient creates a coordination client that uses channels to connect to a
// local Coordinator. (The typical alternative is a DRPC-based client.) The client registers with
// the coordinator as clientID, authorized as a client of agentID.
func NewInMemoryCoordinatorClient(
	logger slog.Logger,
	clientID, agentID uuid.UUID,
	coordinator Coordinator,
) CoordinatorClient {
	ctx, cancel := context.WithCancel(context.Background())
	client := &inMemoryCoordClient{
		ctx:    ctx,
		cancel: cancel,
		logger: logger.With(slog.F("agent_id", agentID), slog.F("client_id", clientID)),
	}

	// use the background context since we will depend exclusively on closing the req channel to
	// tell the coordinator we are done.
	name := fmt.Sprintf("inmemory%s", clientID)
	client.reqs, client.resps = coordinator.Coordinate(
		context.Background(), clientID, name, ClientCoordinateeAuth{AgentID: agentID},
	)
	return client
}
// DERPMapSetter is the part of the data plane that the DERP controller programs with received
// DERP maps.
type DERPMapSetter interface {
// SetDERPMap applies the given DERP map.
SetDERPMap(derpMap *tailcfg.DERPMap)
}
// basicDERPController is a DERPController that passes every received DERP map to a
// DERPMapSetter.
type basicDERPController struct {
logger slog.Logger
// setter receives each DERP map read from a client.
setter DERPMapSetter
}
// New starts a goroutine that receives DERP maps from client and applies them to the
// controller's setter, and returns a handle for closing it and waiting for errors.
func (b *basicDERPController) New(client DERPClient) CloserWaiter {
	loop := &derpSetLoop{
		logger:       b.logger,
		setter:       b.setter,
		client:       client,
		recvLoopDone: make(chan struct{}),
		errChan:      make(chan error, 1),
	}
	go loop.recvLoop()
	return loop
}
// NewBasicDERPController creates a DERPController that applies received DERP maps to setter.
func NewBasicDERPController(logger slog.Logger, setter DERPMapSetter) DERPController {
	ctrl := new(basicDERPController)
	ctrl.logger = logger
	ctrl.setter = setter
	return ctrl
}
// derpSetLoop is a single running receive loop over one DERPClient, created by
// basicDERPController.New.
type derpSetLoop struct {
logger slog.Logger
setter DERPMapSetter
client DERPClient
// Mutex is held by Close for its whole duration.
sync.Mutex
// closed records that Close has already closed the client.
closed bool
// errChan carries the receive error returned from Wait; it is created with a buffer of 1.
errChan chan error
// recvLoopDone is closed when recvLoop exits.
recvLoopDone chan struct{}
}
// Close closes the client (on the first call only) and then waits for the receive loop to exit.
// It returns ctx.Err() if ctx is done first; otherwise it returns the error from closing the
// client, which is nil on calls after the first.
func (l *derpSetLoop) Close(ctx context.Context) error {
	l.Lock()
	defer l.Unlock()

	var closeErr error
	if !l.closed {
		l.closed = true
		closeErr = l.client.Close()
	}

	select {
	case <-l.recvLoopDone:
		return closeErr
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Wait returns the channel on which the receive loop reports the error that ended it.
func (l *derpSetLoop) Wait() <-chan error {
return l.errChan
}
// recvLoop receives DERP maps from the client and applies each to the setter until Recv fails.
// The failure is reported on errChan without blocking, and recvLoopDone is closed on exit.
func (l *derpSetLoop) recvLoop() {
	defer close(l.recvLoopDone)
	for {
		derpMap, err := l.client.Recv()
		if err == nil {
			l.logger.Debug(context.Background(), "got new DERP Map", slog.F("derp_map", derpMap))
			l.setter.SetDERPMap(derpMap)
			continue
		}

		l.logger.Debug(context.Background(), "failed to receive DERP message", slog.Error(err))
		select {
		case l.errChan <- err:
		default: // cannot deliver without blocking; drop it
		}
		return
	}
}
// BasicTelemetryController sends telemetry events using the most recently connected
// TelemetryClient.
type BasicTelemetryController struct {
logger slog.Logger
// Mutex guards client and unavailable.
sync.Mutex
// client is the most recently connected client; nil until New is called.
client TelemetryClient
// unavailable is set once the current client's server is found not to support telemetry, and
// is reset when a new client is connected.
unavailable bool
}
// New installs client as the current telemetry client and clears the unavailable flag left over
// from any previous client.
func (b *BasicTelemetryController) New(client TelemetryClient) {
	b.Lock()
	defer b.Unlock()

	b.client, b.unavailable = client, false
	b.logger.Debug(context.Background(), "new telemetry client connected to controller")
}
// SendTelemetryEvent posts event using the current client. The event is dropped (with a debug
// log) when no client is connected or the current client has been marked unavailable. The RPC is
// made without holding the mutex; if it reports that telemetry is unsupported, the controller is
// marked unavailable, provided the client has not been replaced in the meantime.
func (b *BasicTelemetryController) SendTelemetryEvent(event *proto.TelemetryEvent) {
	// Snapshot the state under the lock, then act on the snapshot.
	b.Lock()
	client, unavailable := b.client, b.unavailable
	b.Unlock()

	switch {
	case client == nil:
		b.logger.Debug(context.Background(),
			"telemetry event dropped; no client", slog.F("event", event))
		return
	case unavailable:
		b.logger.Debug(context.Background(),
			"telemetry event dropped; unavailable", slog.F("event", event))
		return
	}

	if !sendTelemetry(b.logger, client, event) {
		return
	}

	b.Lock()
	defer b.Unlock()
	if b.client == client {
		b.unavailable = true
	}
}
// NewBasicTelemetryController creates a BasicTelemetryController with no client connected; events
// are dropped until New is called.
func NewBasicTelemetryController(logger slog.Logger) *BasicTelemetryController {
	ctrl := new(BasicTelemetryController)
	ctrl.logger = logger
	return ctrl
}
// Compile-time checks that *BasicTelemetryController implements TelemetrySink and
// TelemetryController.
var (
_ TelemetrySink = &BasicTelemetryController{}
_ TelemetryController = &BasicTelemetryController{}
)
// sendTelemetry posts a single event via client with a 5 second timeout. It returns true when
// the error indicates that the server does not implement the telemetry RPC (a dRPC Unimplemented
// code, or a protocol error mentioning "unknown rpc: "). Any other error is logged as a warning
// and false is returned.
func sendTelemetry(
	logger slog.Logger, client TelemetryClient, event *proto.TelemetryEvent,
) (
	unavailable bool,
) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req := &proto.TelemetryRequest{Events: []*proto.TelemetryEvent{event}}
	_, err := client.PostTelemetry(ctx, req)

	unsupported := drpcerr.Code(err) == drpcerr.Unimplemented ||
		(drpc.ProtocolError.Has(err) && strings.Contains(err.Error(), "unknown rpc: "))
	switch {
	case unsupported:
		logger.Debug(
			context.Background(),
			"attempted to send telemetry to a server that doesn't support it",
			slog.Error(err),
		)
		return true
	case err != nil:
		logger.Warn(
			context.Background(),
			"failed to post telemetry event",
			slog.F("event", event), slog.Error(err),
		)
	}
	return false
}
// basicResumeTokenController is a ResumeTokenController that holds the latest resume token
// obtained by its current refresher.
type basicResumeTokenController struct {
logger slog.Logger
// Mutex guards token and refresher.
sync.Mutex
// token is the most recent response stored by the current refresher; nil until one succeeds.
token *proto.RefreshResumeTokenResponse
// refresher is the refresher created by the most recent call to New.
refresher *basicResumeTokenRefresher
// for testing
clock quartz.Clock
}
// New closes the previous refresher, if there is one, and starts a new refresher that uses
// client. The new refresher is returned as the CloserWaiter.
func (b *basicResumeTokenController) New(client ResumeTokenClient) CloserWaiter {
	b.Lock()
	defer b.Unlock()

	if prev := b.refresher; prev != nil {
		if cErr := prev.Close(context.Background()); cErr != nil {
			b.logger.Debug(context.Background(), "closed previous refresher", slog.Error(cErr))
		}
	}

	next := newBasicResumeTokenRefresher(b.logger, b.clock, b, client)
	b.refresher = next
	return next
}
// Token returns the stored resume token and true. It returns "" and false when no token has been
// stored yet or when the stored token's expiry is before the controller clock's current time.
func (b *basicResumeTokenController) Token() (string, bool) {
	b.Lock()
	defer b.Unlock()

	tok := b.token
	if tok == nil || tok.ExpiresAt.AsTime().Before(b.clock.Now()) {
		return "", false
	}
	return tok.Token, true
}
// NewBasicResumeTokenController creates a ResumeTokenController with no token stored. The clock
// is used for expiry checks and for scheduling refreshes.
func NewBasicResumeTokenController(logger slog.Logger, clock quartz.Clock) ResumeTokenController {
	ctrl := new(basicResumeTokenController)
	ctrl.logger = logger
	ctrl.clock = clock
	return ctrl
}
// basicResumeTokenRefresher periodically refreshes the resume token over one ResumeTokenClient
// and stores the result on its controller.
type basicResumeTokenRefresher struct {
logger slog.Logger
// ctx is canceled by Close (via cancel); it is passed to the refresh RPC.
ctx context.Context
cancel context.CancelFunc
// ctrl is the controller whose token this refresher updates.
ctrl *basicResumeTokenController
client ResumeTokenClient
// errCh carries the value returned from Wait; it is created with a buffer of 1.
errCh chan error
// Mutex guards closed and is held while stopping or resetting timer.
sync.Mutex
closed bool
// timer schedules the next call to refresh.
timer *quartz.Timer
}
// Close cancels the refresher's context and, on the first call, stops the refresh timer and
// offers a nil error to the Wait channel. The context argument is ignored and the result is
// always nil.
func (r *basicResumeTokenRefresher) Close(_ context.Context) error {
	r.cancel()

	r.Lock()
	defer r.Unlock()
	if !r.closed {
		r.closed = true
		r.timer.Stop()
		select {
		case r.errCh <- nil:
		default: // already have an error
		}
	}
	return nil
}
// Wait returns the channel on which the refresher reports a refresh error, or nil when it was
// closed or its context ended.
func (r *basicResumeTokenRefresher) Wait() <-chan error {
return r.errCh
}
// never is the largest representable duration. It is the initial delay of the refresh timer, so
// that the timer is not expected to fire until it has been Reset.
const never time.Duration = math.MaxInt64
// newBasicResumeTokenRefresher creates a refresher for ctrl that uses client, and immediately
// starts the first refresh in a new goroutine. Its timer is created with a delay of never and is
// rescheduled by refresh after each successful refresh.
func newBasicResumeTokenRefresher(
	logger slog.Logger, clock quartz.Clock,
	ctrl *basicResumeTokenController, client ResumeTokenClient,
) *basicResumeTokenRefresher {
	ctx, cancel := context.WithCancel(context.Background())
	refresher := &basicResumeTokenRefresher{
		logger: logger,
		ctx:    ctx,
		cancel: cancel,
		ctrl:   ctrl,
		client: client,
		errCh:  make(chan error, 1),
	}
	refresher.timer = clock.AfterFunc(never, refresher.refresh)
	go refresher.refresh()
	return refresher
}
// refresh performs one RefreshResumeToken RPC. It is run once in its own goroutine when the
// refresher is created, and afterwards as the timer callback.
//
// On success the response is stored on the controller (only if this refresher is still the
// controller's current one) and the timer is reset to fire after the response's RefreshIn, or
// after 30 minutes if that is not positive. On failure the error is offered to the Wait channel
// and no further refresh is scheduled; context cancellation or deadline errors are reported as
// nil.
func (r *basicResumeTokenRefresher) refresh() {
	if r.ctx.Err() != nil {
		return // context done, no need to refresh
	}
	res, err := r.client.RefreshResumeToken(r.ctx, &proto.RefreshResumeTokenRequest{})
	if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
		// these can only come from being closed, no need to log
		select {
		case r.errCh <- nil:
		default: // already have an error
		}
		return
	}
	if err != nil {
		r.logger.Error(r.ctx, "error refreshing coordinator resume token", slog.Error(err))
		select {
		case r.errCh <- err:
		default: // already have an error
		}
		return
	}
	r.logger.Debug(r.ctx, "refreshed coordinator resume token",
		slog.F("expires_at", res.GetExpiresAt()),
		slog.F("refresh_in", res.GetRefreshIn()),
	)
	r.ctrl.Lock()
	if r.ctrl.refresher == r { // don't overwrite if we're not the current refresher
		r.ctrl.token = res
	} else {
		r.logger.Debug(context.Background(), "not writing token because we have a new client")
	}
	r.ctrl.Unlock()
	// Use the nil-safe getter, as the log statement above does, so that a nil response with a
	// nil error cannot cause a nil pointer dereference here.
	dur := res.GetRefreshIn().AsDuration()
	if dur <= 0 {
		// A sensible delay to refresh again.
		dur = 30 * time.Minute
	}
	// Check closed under the lock so that the timer is not rescheduled after Close stopped it.
	r.Lock()
	defer r.Unlock()
	if r.closed {
		return
	}
	r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
}