chore: refactor ServerTailnet to use tailnet.Controllers (#15408)

chore of #14729

Refactors the `ServerTailnet` to use `tailnet.Controller` so that we reuse the logic for reconnection and handling control messages instead of reimplementing it. This unifies our "client" use of the Tailscale API across the CLI, coderd, and wsproxy.
Spike Curtis
2024-11-08 13:18:56 +04:00
committed by GitHub
parent f7cbf5dd79
commit 8c00ebc6ee
20 changed files with 491 additions and 1240 deletions
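
For orientation, here is a minimal sketch of the wiring pattern the diff below moves `ServerTailnet` onto. The `tailnet` calls are the ones that appear in the diff; the package and function names in the sketch itself are illustrative, and the `conn`, `dialer`, and coordination controller are assumed to be constructed elsewhere.

```go
package example // illustrative package; not part of the commit

import (
	"context"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/tailnet"
)

// wireController shows the reuse this change is after: rather than hand-rolled
// reconnect and update loops, ServerTailnet hands a dialer to tailnet.Controller
// and plugs in per-protocol sub-controllers.
func wireController(
	ctx context.Context,
	logger slog.Logger,
	conn *tailnet.Conn,
	dialer tailnet.ControlProtocolDialer,
	coordCtrl tailnet.CoordinationController,
) *tailnet.Controller {
	controller := tailnet.NewController(logger, dialer)
	// DERP maps received from the control plane (or the in-memory poller) are
	// applied to the conn.
	controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
	// Coordination is pluggable; coderd supplies the MultiAgentController added below.
	controller.CoordCtrl = coordCtrl
	// Run dials the control plane, reconnects on failure, and stops when ctx is
	// canceled; shutdown waits on controller.Closed(), as ServerTailnet.Close does below.
	controller.Run(ctx)
	return controller
}
```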


@@ -30,7 +30,7 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/retry"
"github.com/coder/coder/v2/tailnet/proto"
)
var tailnetTransport *http.Transport
@@ -53,9 +53,8 @@ func NewServerTailnet(
ctx context.Context,
logger slog.Logger,
derpServer *derp.Server,
derpMapFn func() *tailcfg.DERPMap,
dialer tailnet.ControlProtocolDialer,
derpForceWebSockets bool,
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
blockEndpoints bool,
traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
@@ -91,46 +90,26 @@ func NewServerTailnet(
})
}
bgRoutines := &sync.WaitGroup{}
originalDerpMap := derpMapFn()
tracer := traceProvider.Tracer(tracing.TracerName)
controller := tailnet.NewController(logger, dialer)
// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
// there is an embedded relay, we use the local in-memory dialer.
conn.SetDERPMap(originalDerpMap)
bgRoutines.Add(1)
go func() {
defer bgRoutines.Done()
defer logger.Debug(ctx, "polling DERPMap exited")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-serverCtx.Done():
return
case <-ticker.C:
}
newDerpMap := derpMapFn()
if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
conn.SetDERPMap(newDerpMap)
originalDerpMap = newDerpMap
}
}
}()
controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
controller.CoordCtrl = coordCtrl
// TODO: support controller.TelemetryCtrl
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
bgRoutines: bgRoutines,
logger: logger,
tracer: traceProvider.Tracer(tracing.TracerName),
conn: conn,
coordinatee: conn,
getMultiAgent: getMultiAgent,
agentConnectionTimes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
tracer: tracer,
conn: conn,
coordinatee: conn,
controller: controller,
coordCtrl: coordCtrl,
transport: tailnetTransport.Clone(),
connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "coder",
Subsystem: "servertailnet",
@@ -146,7 +125,7 @@ func NewServerTailnet(
}
tn.transport.DialContext = tn.dialContext
// These options are mostly just picked at random, and they can likely be
// fine tuned further. Generally, users are running applications in dev mode
// fine-tuned further. Generally, users are running applications in dev mode
// which can generate hundreds of requests per page load, so we increased
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
// conns.
@@ -164,23 +143,7 @@ func NewServerTailnet(
InsecureSkipVerify: true,
}
agentConn, err := getMultiAgent(ctx)
if err != nil {
return nil, xerrors.Errorf("get initial multi agent: %w", err)
}
tn.agentConn.Store(&agentConn)
// registering the callback also triggers send of the initial node
tn.coordinatee.SetNodeCallback(tn.nodeCallback)
tn.bgRoutines.Add(2)
go func() {
defer tn.bgRoutines.Done()
tn.watchAgentUpdates()
}()
go func() {
defer tn.bgRoutines.Done()
tn.expireOldAgents()
}()
tn.controller.Run(tn.ctx)
return tn, nil
}
@@ -190,18 +153,6 @@ func (s *ServerTailnet) Conn() *tailnet.Conn {
return s.conn
}
func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
pn, err := tailnet.NodeToProto(node)
if err != nil {
s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
return
}
err = s.getAgentConn().UpdateSelf(pn)
if err != nil {
s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
}
}
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
s.connsPerAgent.Describe(descs)
s.totalConns.Describe(descs)
@@ -212,125 +163,9 @@ func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
s.totalConns.Collect(metrics)
}
func (s *ServerTailnet) expireOldAgents() {
defer s.logger.Debug(s.ctx, "stopped expiring old agents")
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
}
s.doExpireOldAgents(cutoff)
}
}
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
// TODO: add some attrs to this.
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
defer span.End()
start := time.Now()
deletedCount := 0
s.nodesMu.Lock()
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentConnectionTimes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
err := agentConn.UnsubscribeAgent(agentID)
if err != nil {
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
continue
}
deletedCount++
delete(s.agentConnectionTimes, agentID)
}
}
s.nodesMu.Unlock()
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
slog.F("deleted", deletedCount),
slog.F("took", time.Since(start)),
)
}
func (s *ServerTailnet) watchAgentUpdates() {
defer s.logger.Debug(s.ctx, "stopped watching agent updates")
for {
conn := s.getAgentConn()
resp, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
s.coordinatee.SetAllPeersLost()
s.reinitCoordinator()
continue
}
return
}
err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
if err != nil {
if xerrors.Is(err, tailnet.ErrConnClosed) {
s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
return
}
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
return
}
}
}
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
return *s.agentConn.Load()
}
func (s *ServerTailnet) reinitCoordinator() {
start := time.Now()
for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
s.nodesMu.Lock()
agentConn, err := s.getMultiAgent(s.ctx)
if err != nil {
s.nodesMu.Unlock()
s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
continue
}
s.agentConn.Store(&agentConn)
// reset the Node callback, which triggers the conn to send the node immediately, and also
// register for updates
s.coordinatee.SetNodeCallback(s.nodeCallback)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentConnectionTimes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.logger.Info(s.ctx, "successfully reinitialized multiagent",
slog.F("agents", len(s.agentConnectionTimes)),
slog.F("took", time.Since(start)),
)
s.nodesMu.Unlock()
return
}
}
type ServerTailnet struct {
ctx context.Context
cancel func()
bgRoutines *sync.WaitGroup
ctx context.Context
cancel func()
logger slog.Logger
tracer trace.Tracer
@@ -340,15 +175,8 @@ type ServerTailnet struct {
conn *tailnet.Conn
coordinatee tailnet.Coordinatee
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
nodesMu sync.Mutex
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
// to.
agentConnectionTimes map[uuid.UUID]time.Time
// agentTickets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
controller *tailnet.Controller
coordCtrl *MultiAgentController
transport *http.Transport
@@ -446,38 +274,6 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
}, nil
}
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()
_, ok := s.agentConnectionTimes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
err := s.getAgentConn().SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}
s.agentConnectionTimes[agentID] = time.Now()
return nil
}
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
id := uuid.New()
s.nodesMu.Lock()
s.agentTickets[agentID][id] = struct{}{}
s.nodesMu.Unlock()
return func() {
s.nodesMu.Lock()
delete(s.agentTickets[agentID], id)
s.nodesMu.Unlock()
}
}
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
var (
conn *workspacesdk.AgentConn
@@ -485,11 +281,11 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*work
)
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
err := s.coordCtrl.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)
ret = s.coordCtrl.acquireTicket(agentID)
conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
AgentID: agentID,
@@ -548,7 +344,8 @@ func (s *ServerTailnet) Close() error {
s.cancel()
_ = s.conn.Close()
s.transport.CloseIdleConnections()
s.bgRoutines.Wait()
s.coordCtrl.Close()
<-s.controller.Closed()
return nil
}
@@ -566,3 +363,277 @@ func (c *instrumentedConn) Close() error {
})
return c.Conn.Close()
}
// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
// have no active connections and haven't been used in a while.
type MultiAgentController struct {
*tailnet.BasicCoordinationController
logger slog.Logger
tracer trace.Tracer
mu sync.Mutex
// connectionTimes is a map of agents the server wants to keep a connection to. It
// contains the last time the agent was connected to.
connectionTimes map[uuid.UUID]time.Time
// tickets is a map of destinations to a set of connection tickets, representing open
// connections to the destination
tickets map[uuid.UUID]map[uuid.UUID]struct{}
coordination *tailnet.BasicCoordination
cancel context.CancelFunc
expireOldAgentsDone chan struct{}
}
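// New starts a coordination over the given client and re-adds tunnels for every
// agent we are already tracking, so existing destinations survive a reconnect.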
func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
b := m.BasicCoordinationController.NewCoordination(client)
// resync all destinations
m.mu.Lock()
defer m.mu.Unlock()
m.coordination = b
for agentID := range m.connectionTimes {
err := client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
slog.Error(err))
b.SendErr(err)
_ = client.Close()
m.coordination = nil
break
}
}
return b
}
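// ensureAgent adds a tunnel to the agent if we aren't already tracking it, and
// refreshes its last-connection time.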
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.connectionTimes[agentID]
// If we don't have the agent, subscribe.
if !ok {
m.logger.Debug(context.Background(),
"subscribing to agent", slog.F("agent_id", agentID))
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
err = xerrors.Errorf("subscribe agent: %w", err)
m.coordination.SendErr(err)
_ = m.coordination.Client.Close()
m.coordination = nil
return err
}
}
m.tickets[agentID] = map[uuid.UUID]struct{}{}
}
m.connectionTimes[agentID] = time.Now()
return nil
}
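// acquireTicket records an open connection to the agent and returns a function
// that releases it.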
func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
id := uuid.New()
m.mu.Lock()
defer m.mu.Unlock()
m.tickets[agentID][id] = struct{}{}
return func() {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.tickets[agentID], id)
}
}
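// expireOldAgents periodically prunes agents that have no open connections and
// haven't been dialed since the cutoff.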
func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
defer close(m.expireOldAgentsDone)
defer m.logger.Debug(context.Background(), "stopped expiring old agents")
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
m.doExpireOldAgents(ctx, cutoff)
}
}
func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
// TODO: add some attrs to this.
ctx, span := m.tracer.Start(ctx, tracing.FuncName())
defer span.End()
start := time.Now()
deletedCount := 0
m.mu.Lock()
defer m.mu.Unlock()
m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
for agentID, lastConnection := range m.connectionTimes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
if m.coordination != nil {
err := m.coordination.Client.Send(&proto.CoordinateRequest{
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
})
if err != nil {
m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
// close the client because we do not want to do a graceful disconnect by
// closing the coordination.
_ = m.coordination.Client.Close()
m.coordination = nil
// Here we continue deleting any inactive agents: there is no point in
// re-establishing tunnels to expired agents when we eventually reconnect.
}
}
deletedCount++
delete(m.connectionTimes, agentID)
}
}
m.logger.Debug(ctx, "pruned inactive agents",
slog.F("deleted", deletedCount),
slog.F("took", time.Since(start)),
)
}
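// Close stops the expiry loop and waits for it to exit.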
func (m *MultiAgentController) Close() {
m.cancel()
<-m.expireOldAgentsDone
}
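// NewMultiAgentController creates a MultiAgentController and starts its agent-expiry loop.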
func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
m := &MultiAgentController{
BasicCoordinationController: &tailnet.BasicCoordinationController{
Logger: logger,
Coordinatee: coordinatee,
SendAcks: false, // we are a client, connecting to multiple agents
},
logger: logger,
tracer: tracer,
connectionTimes: make(map[uuid.UUID]time.Time),
tickets: make(map[uuid.UUID]map[uuid.UUID]struct{}),
expireOldAgentsDone: make(chan struct{}),
}
ctx, m.cancel = context.WithCancel(ctx)
go m.expireOldAgents(ctx)
return m
}
// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
// service running in the same memory space.
type InmemTailnetDialer struct {
CoordPtr *atomic.Pointer[tailnet.Coordinator]
DERPFn func() *tailcfg.DERPMap
Logger slog.Logger
ClientID uuid.UUID
}
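// Dial returns coordination and DERP clients backed by the in-memory coordinator
// and the DERP map function; it fails if the coordinator has not been initialized yet.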
func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
coord := a.CoordPtr.Load()
if coord == nil {
return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
}
coordClient := tailnet.NewInMemoryCoordinatorClient(
a.Logger, a.ClientID, tailnet.SingleTailnetCoordinateeAuth{}, *coord)
derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
return tailnet.ControlProtocolClients{
Closer: closeAll{coord: coordClient, derp: derpClient},
Coordinator: coordClient,
DERP: derpClient,
}, nil
}
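// newPollingDERPClient starts a client that re-reads the DERP map on an interval
// and surfaces changes via Recv.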
func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
ctx, cancel := context.WithCancel(context.Background())
a := &pollingDERPClient{
fn: derpFn,
ctx: ctx,
cancel: cancel,
logger: logger,
ch: make(chan *tailcfg.DERPMap),
loopDone: make(chan struct{}),
}
go a.pollDERP()
return a
}
// pollingDERPClient is a DERP client that just calls a function on a polling
// interval
type pollingDERPClient struct {
fn func() *tailcfg.DERPMap
logger slog.Logger
ctx context.Context
cancel context.CancelFunc
loopDone chan struct{}
lastDERPMap *tailcfg.DERPMap
ch chan *tailcfg.DERPMap
}
// Close the DERP client and stop its polling loop.
func (a *pollingDERPClient) Close() error {
a.cancel()
<-a.loopDone
return nil
}
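// Recv blocks until the polled DERP map changes or the client is closed.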
func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
select {
case <-a.ctx.Done():
return nil, a.ctx.Err()
case dm := <-a.ch:
return dm, nil
}
}
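// pollDERP checks the DERP map every 5 seconds and forwards it to Recv when it changes.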
func (a *pollingDERPClient) pollDERP() {
defer close(a.loopDone)
defer a.logger.Debug(a.ctx, "polling DERPMap exited")
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-a.ctx.Done():
return
case <-ticker.C:
}
newDerpMap := a.fn()
if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
select {
case <-a.ctx.Done():
return
case a.ch <- newDerpMap:
}
}
}
}
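// closeAll closes both control protocol clients; the coordinator error, if any,
// takes precedence.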
type closeAll struct {
coord tailnet.CoordinatorClient
derp tailnet.DERPClient
}
func (c closeAll) Close() error {
cErr := c.coord.Close()
dErr := c.derp.Close()
if cErr != nil {
return cErr
}
return dErr
}
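
A minimal sketch of how the new in-memory dialer is meant to be used (assuming `InmemTailnetDialer` lives in the same package as the file above; the helper name and the fresh `uuid.New()` client ID are illustrative, not taken from the commit): `tailnet.Controller` calls `Dial` on every connect and reconnect, loading whatever coordinator is currently stored in the pointer.

```go
package coderd // assumption: the file above is in this package

import (
	"context"
	"sync/atomic"

	"github.com/google/uuid"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/tailnet"
)

// dialInmem is a hypothetical helper that shows the dialer's moving parts: a
// swappable coordinator pointer and a DERP map getter, both served from the
// same process.
func dialInmem(
	ctx context.Context,
	logger slog.Logger,
	coordPtr *atomic.Pointer[tailnet.Coordinator],
	derpFn func() *tailcfg.DERPMap,
) (tailnet.ControlProtocolClients, error) {
	dialer := &InmemTailnetDialer{
		CoordPtr: coordPtr,
		DERPFn:   derpFn,
		Logger:   logger,
		ClientID: uuid.New(), // illustrative; the real caller supplies its own ID
	}
	// Dial loads the current coordinator and returns in-memory coordination and
	// DERP clients; in production tailnet.Controller makes this call, not user code.
	return dialer.Dial(ctx, nil)
}
```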