Mirror of https://github.com/coder/coder.git (synced 2025-07-03 16:13:58 +00:00)
chore: refactor ServerTailnet to use tailnet.Controllers (#15408)
Chore for #14729. Refactors the `ServerTailnet` to use `tailnet.Controller` so that we reuse the logic around reconnection and handling control messages instead of reimplementing it. This unifies our "client" use of the tailscale API across the CLI, coderd, and wsproxy.
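As a rough illustration of the wiring this refactor adopts, here is a minimal sketch built only from calls that appear in the diff below (`tailnet.NewController`, `tailnet.NewBasicDERPController`, `CoordCtrl`, `Run`, `Closed`); the `wireController` helper, the package name, and the `cdr.dev/slog` import are assumptions for illustration, not part of the change.

```go
// Sketch of the composition pattern NewServerTailnet adopts in this change:
// a single tailnet.Controller owns dialing and reconnection, and per-protocol
// controllers (DERP map updates, coordination) plug into it.
package example // hypothetical package, for illustration only

import (
	"context"

	"cdr.dev/slog" // assumed logger package; the diff only shows slog.Logger

	"github.com/coder/coder/v2/tailnet"
)

// wireController is a hypothetical helper that shows the calls used in the
// diff: NewController, NewBasicDERPController, CoordCtrl, and Run.
func wireController(
	ctx context.Context,
	logger slog.Logger,
	conn *tailnet.Conn,
	dialer tailnet.ControlProtocolDialer,
	coordCtrl tailnet.CoordinationController,
) *tailnet.Controller {
	controller := tailnet.NewController(logger, dialer)
	controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
	controller.CoordCtrl = coordCtrl
	// Run starts the dial/reconnect loop; callers wait on Closed() during
	// shutdown, as ServerTailnet.Close does in the diff.
	controller.Run(ctx)
	return controller
}
```

In the diff, the coordination controller is the new `MultiAgentController`, and the dialer used by coderd is the new `InmemTailnetDialer`, which connects to the coordinator and DERP map service running in the same process.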
@@ -30,7 +30,7 @@ import (
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/site"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/retry"
	"github.com/coder/coder/v2/tailnet/proto"
)

var tailnetTransport *http.Transport
@@ -53,9 +53,8 @@ func NewServerTailnet(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	derpMapFn func() *tailcfg.DERPMap,
	dialer tailnet.ControlProtocolDialer,
	derpForceWebSockets bool,
	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
	blockEndpoints bool,
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
@@ -91,46 +90,26 @@ func NewServerTailnet(
		})
	}

	bgRoutines := &sync.WaitGroup{}
	originalDerpMap := derpMapFn()
	tracer := traceProvider.Tracer(tracing.TracerName)

	controller := tailnet.NewController(logger, dialer)
	// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
	// there is an embedded relay, we use the local in-memory dialer.
	conn.SetDERPMap(originalDerpMap)
	bgRoutines.Add(1)
	go func() {
		defer bgRoutines.Done()
		defer logger.Debug(ctx, "polling DERPMap exited")

		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-serverCtx.Done():
				return
			case <-ticker.C:
			}

			newDerpMap := derpMapFn()
			if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
				conn.SetDERPMap(newDerpMap)
				originalDerpMap = newDerpMap
			}
		}
	}()
	controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
	coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
	controller.CoordCtrl = coordCtrl
	// TODO: support controller.TelemetryCtrl

	tn := &ServerTailnet{
		ctx: serverCtx,
		cancel: cancel,
		bgRoutines: bgRoutines,
		logger: logger,
		tracer: traceProvider.Tracer(tracing.TracerName),
		conn: conn,
		coordinatee: conn,
		getMultiAgent: getMultiAgent,
		agentConnectionTimes: map[uuid.UUID]time.Time{},
		agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
		transport: tailnetTransport.Clone(),
		ctx: serverCtx,
		cancel: cancel,
		logger: logger,
		tracer: tracer,
		conn: conn,
		coordinatee: conn,
		controller: controller,
		coordCtrl: coordCtrl,
		transport: tailnetTransport.Clone(),
		connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
@@ -146,7 +125,7 @@ func NewServerTailnet(
	}
	tn.transport.DialContext = tn.dialContext
	// These options are mostly just picked at random, and they can likely be
	// fine tuned further. Generally, users are running applications in dev mode
	// fine-tuned further. Generally, users are running applications in dev mode
	// which can generate hundreds of requests per page load, so we increased
	// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
	// conns.
@@ -164,23 +143,7 @@ func NewServerTailnet(
		InsecureSkipVerify: true,
	}

	agentConn, err := getMultiAgent(ctx)
	if err != nil {
		return nil, xerrors.Errorf("get initial multi agent: %w", err)
	}
	tn.agentConn.Store(&agentConn)
	// registering the callback also triggers send of the initial node
	tn.coordinatee.SetNodeCallback(tn.nodeCallback)

	tn.bgRoutines.Add(2)
	go func() {
		defer tn.bgRoutines.Done()
		tn.watchAgentUpdates()
	}()
	go func() {
		defer tn.bgRoutines.Done()
		tn.expireOldAgents()
	}()
	tn.controller.Run(tn.ctx)
	return tn, nil
}

@@ -190,18 +153,6 @@ func (s *ServerTailnet) Conn() *tailnet.Conn {
	return s.conn
}

func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
	pn, err := tailnet.NodeToProto(node)
	if err != nil {
		s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
		return
	}
	err = s.getAgentConn().UpdateSelf(pn)
	if err != nil {
		s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
	}
}

func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
	s.connsPerAgent.Describe(descs)
	s.totalConns.Describe(descs)
@@ -212,125 +163,9 @@ func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
	s.totalConns.Collect(metrics)
}

func (s *ServerTailnet) expireOldAgents() {
	defer s.logger.Debug(s.ctx, "stopped expiring old agents")
	const (
		tick   = 5 * time.Minute
		cutoff = 30 * time.Minute
	)

	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		s.doExpireOldAgents(cutoff)
	}
}

func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	s.nodesMu.Lock()
	s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
	agentConn := s.getAgentConn()
	for agentID, lastConnection := range s.agentConnectionTimes {
		// If no one has connected since the cutoff and there are no active
		// connections, remove the agent.
		if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
			err := agentConn.UnsubscribeAgent(agentID)
			if err != nil {
				s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
				continue
			}
			deletedCount++
			delete(s.agentConnectionTimes, agentID)
		}
	}
	s.nodesMu.Unlock()
	s.logger.Debug(s.ctx, "successfully pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}

func (s *ServerTailnet) watchAgentUpdates() {
	defer s.logger.Debug(s.ctx, "stopped watching agent updates")
	for {
		conn := s.getAgentConn()
		resp, ok := conn.NextUpdate(s.ctx)
		if !ok {
			if conn.IsClosed() && s.ctx.Err() == nil {
				s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
				s.coordinatee.SetAllPeersLost()
				s.reinitCoordinator()
				continue
			}
			return
		}

		err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
		if err != nil {
			if xerrors.Is(err, tailnet.ErrConnClosed) {
				s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
				return
			}
			s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
			return
		}
	}
}

func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
	return *s.agentConn.Load()
}

func (s *ServerTailnet) reinitCoordinator() {
	start := time.Now()
	for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); {
		s.nodesMu.Lock()
		agentConn, err := s.getMultiAgent(s.ctx)
		if err != nil {
			s.nodesMu.Unlock()
			s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
			continue
		}
		s.agentConn.Store(&agentConn)
		// reset the Node callback, which triggers the conn to send the node immediately, and also
		// register for updates
		s.coordinatee.SetNodeCallback(s.nodeCallback)

		// Resubscribe to all of the agents we're tracking.
		for agentID := range s.agentConnectionTimes {
			err := agentConn.SubscribeAgent(agentID)
			if err != nil {
				s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
			}
		}

		s.logger.Info(s.ctx, "successfully reinitialized multiagent",
			slog.F("agents", len(s.agentConnectionTimes)),
			slog.F("took", time.Since(start)),
		)
		s.nodesMu.Unlock()
		return
	}
}

type ServerTailnet struct {
	ctx context.Context
	cancel func()
	bgRoutines *sync.WaitGroup
	ctx context.Context
	cancel func()

	logger slog.Logger
	tracer trace.Tracer
@@ -340,15 +175,8 @@ type ServerTailnet struct {
	conn *tailnet.Conn
	coordinatee tailnet.Coordinatee

	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
	agentConn atomic.Pointer[tailnet.MultiAgentConn]
	nodesMu sync.Mutex
	// agentConnectionTimes is a map of agent tailnetNodes the server wants to
	// keep a connection to. It contains the last time the agent was connected
	// to.
	agentConnectionTimes map[uuid.UUID]time.Time
	// agentTockets holds a map of all open connections to an agent.
	agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
	controller *tailnet.Controller
	coordCtrl *MultiAgentController

	transport *http.Transport

@@ -446,38 +274,6 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
	}, nil
}

func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
	s.nodesMu.Lock()
	defer s.nodesMu.Unlock()

	_, ok := s.agentConnectionTimes[agentID]
	// If we don't have the node, subscribe.
	if !ok {
		s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
		err := s.getAgentConn().SubscribeAgent(agentID)
		if err != nil {
			return xerrors.Errorf("subscribe agent: %w", err)
		}
		s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
	}

	s.agentConnectionTimes[agentID] = time.Now()
	return nil
}

func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()
	s.nodesMu.Lock()
	s.agentTickets[agentID][id] = struct{}{}
	s.nodesMu.Unlock()

	return func() {
		s.nodesMu.Lock()
		delete(s.agentTickets[agentID], id)
		s.nodesMu.Unlock()
	}
}

func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
	var (
		conn *workspacesdk.AgentConn
@@ -485,11 +281,11 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*work
	)

	s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
	err := s.ensureAgent(agentID)
	err := s.coordCtrl.ensureAgent(agentID)
	if err != nil {
		return nil, nil, xerrors.Errorf("ensure agent: %w", err)
	}
	ret = s.acquireTicket(agentID)
	ret = s.coordCtrl.acquireTicket(agentID)

	conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
		AgentID: agentID,
@@ -548,7 +344,8 @@ func (s *ServerTailnet) Close() error {
	s.cancel()
	_ = s.conn.Close()
	s.transport.CloseIdleConnections()
	s.bgRoutines.Wait()
	s.coordCtrl.Close()
	<-s.controller.Closed()
	return nil
}

@@ -566,3 +363,277 @@ func (c *instrumentedConn) Close() error {
	})
	return c.Conn.Close()
}

// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
// have no active connections and haven't been used in a while.
type MultiAgentController struct {
	*tailnet.BasicCoordinationController

	logger slog.Logger
	tracer trace.Tracer

	mu sync.Mutex
	// connectionTimes is a map of agents the server wants to keep a connection to. It
	// contains the last time the agent was connected to.
	connectionTimes map[uuid.UUID]time.Time
	// tickets is a map of destinations to a set of connection tickets, representing open
	// connections to the destination
	tickets map[uuid.UUID]map[uuid.UUID]struct{}
	coordination *tailnet.BasicCoordination

	cancel              context.CancelFunc
	expireOldAgentsDone chan struct{}
}

func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
	b := m.BasicCoordinationController.NewCoordination(client)
	// resync all destinations
	m.mu.Lock()
	defer m.mu.Unlock()
	m.coordination = b
	for agentID := range m.connectionTimes {
		err := client.Send(&proto.CoordinateRequest{
			AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
		})
		if err != nil {
			m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
				slog.Error(err))
			b.SendErr(err)
			_ = client.Close()
			m.coordination = nil
			break
		}
	}
	return b
}

func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, ok := m.connectionTimes[agentID]
	// If we don't have the agent, subscribe.
	if !ok {
		m.logger.Debug(context.Background(),
			"subscribing to agent", slog.F("agent_id", agentID))
		if m.coordination != nil {
			err := m.coordination.Client.Send(&proto.CoordinateRequest{
				AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
			})
			if err != nil {
				err = xerrors.Errorf("subscribe agent: %w", err)
				m.coordination.SendErr(err)
				_ = m.coordination.Client.Close()
				m.coordination = nil
				return err
			}
		}
		m.tickets[agentID] = map[uuid.UUID]struct{}{}
	}
	m.connectionTimes[agentID] = time.Now()
	return nil
}

func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()
	m.mu.Lock()
	defer m.mu.Unlock()
	m.tickets[agentID][id] = struct{}{}

	return func() {
		m.mu.Lock()
		defer m.mu.Unlock()
		delete(m.tickets[agentID], id)
	}
}

func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
	defer close(m.expireOldAgentsDone)
	defer m.logger.Debug(context.Background(), "stopped expiring old agents")
	const (
		tick   = 5 * time.Minute
		cutoff = 30 * time.Minute
	)

	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		m.doExpireOldAgents(ctx, cutoff)
	}
}

func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := m.tracer.Start(ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	m.mu.Lock()
	defer m.mu.Unlock()
	m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
	for agentID, lastConnection := range m.connectionTimes {
		// If no one has connected since the cutoff and there are no active
		// connections, remove the agent.
		if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
			if m.coordination != nil {
				err := m.coordination.Client.Send(&proto.CoordinateRequest{
					RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
				})
				if err != nil {
					m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
					m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
					// close the client because we do not want to do a graceful disconnect by
					// closing the coordination.
					_ = m.coordination.Client.Close()
					m.coordination = nil
					// Here we continue deleting any inactive agents: there is no point in
					// re-establishing tunnels to expired agents when we eventually reconnect.
				}
			}
			deletedCount++
			delete(m.connectionTimes, agentID)
		}
	}
	m.logger.Debug(ctx, "pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}

func (m *MultiAgentController) Close() {
	m.cancel()
	<-m.expireOldAgentsDone
}

func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
	m := &MultiAgentController{
		BasicCoordinationController: &tailnet.BasicCoordinationController{
			Logger: logger,
			Coordinatee: coordinatee,
			SendAcks: false, // we are a client, connecting to multiple agents
		},
		logger: logger,
		tracer: tracer,
		connectionTimes: make(map[uuid.UUID]time.Time),
		tickets: make(map[uuid.UUID]map[uuid.UUID]struct{}),
		expireOldAgentsDone: make(chan struct{}),
	}
	ctx, m.cancel = context.WithCancel(ctx)
	go m.expireOldAgents(ctx)
	return m
}

// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
// service running in the same memory space.
type InmemTailnetDialer struct {
	CoordPtr *atomic.Pointer[tailnet.Coordinator]
	DERPFn func() *tailcfg.DERPMap
	Logger slog.Logger
	ClientID uuid.UUID
}

func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
	coord := a.CoordPtr.Load()
	if coord == nil {
		return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
	}
	coordClient := tailnet.NewInMemoryCoordinatorClient(
		a.Logger, a.ClientID, tailnet.SingleTailnetCoordinateeAuth{}, *coord)
	derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
	return tailnet.ControlProtocolClients{
		Closer: closeAll{coord: coordClient, derp: derpClient},
		Coordinator: coordClient,
		DERP: derpClient,
	}, nil
}

func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
	ctx, cancel := context.WithCancel(context.Background())
	a := &pollingDERPClient{
		fn: derpFn,
		ctx: ctx,
		cancel: cancel,
		logger: logger,
		ch: make(chan *tailcfg.DERPMap),
		loopDone: make(chan struct{}),
	}
	go a.pollDERP()
	return a
}

// pollingDERPClient is a DERP client that just calls a function on a polling
// interval
type pollingDERPClient struct {
	fn func() *tailcfg.DERPMap
	logger slog.Logger
	ctx context.Context
	cancel context.CancelFunc
	loopDone chan struct{}
	lastDERPMap *tailcfg.DERPMap
	ch chan *tailcfg.DERPMap
}

// Close the DERP client
func (a *pollingDERPClient) Close() error {
	a.cancel()
	<-a.loopDone
	return nil
}

func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
	select {
	case <-a.ctx.Done():
		return nil, a.ctx.Err()
	case dm := <-a.ch:
		return dm, nil
	}
}

func (a *pollingDERPClient) pollDERP() {
	defer close(a.loopDone)
	defer a.logger.Debug(a.ctx, "polling DERPMap exited")

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-a.ctx.Done():
			return
		case <-ticker.C:
		}

		newDerpMap := a.fn()
		if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
			select {
			case <-a.ctx.Done():
				return
			case a.ch <- newDerpMap:
			}
		}
	}
}

type closeAll struct {
	coord tailnet.CoordinatorClient
	derp tailnet.DERPClient
}

func (c closeAll) Close() error {
	cErr := c.coord.Close()
	dErr := c.derp.Close()
	if cErr != nil {
		return cErr
	}
	return dErr
}