mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore of #14729 Refactors the `ServerTailnet` to use `tailnet.Controller` so that we reuse logic around reconnection and handling control messages, instead of reimplementing. This unifies our "client" use of the tailscale API across CLI, coderd, and wsproxy.
640 lines
19 KiB
Go
640 lines
19 KiB
Go
package coderd
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/xerrors"
|
|
"tailscale.com/derp"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/coderd/workspaceapps"
|
|
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/coder/v2/site"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
)
|
|
|
|
var tailnetTransport *http.Transport
|
|
|
|
func init() {
|
|
tp, valid := http.DefaultTransport.(*http.Transport)
|
|
if !valid {
|
|
panic("dev error: default transport is the wrong type")
|
|
}
|
|
tailnetTransport = tp.Clone()
|
|
// We do not want to respect the proxy settings from the environment, since
|
|
// all network traffic happens over wireguard.
|
|
tailnetTransport.Proxy = nil
|
|
}
|
|
|
|
var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)
|
|
|
|
// NewServerTailnet creates a new tailnet intended for use by coderd.
|
|
func NewServerTailnet(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
derpServer *derp.Server,
|
|
dialer tailnet.ControlProtocolDialer,
|
|
derpForceWebSockets bool,
|
|
blockEndpoints bool,
|
|
traceProvider trace.TracerProvider,
|
|
) (*ServerTailnet, error) {
|
|
logger = logger.Named("servertailnet")
|
|
conn, err := tailnet.NewConn(&tailnet.Options{
|
|
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
|
|
DERPForceWebSockets: derpForceWebSockets,
|
|
Logger: logger,
|
|
BlockEndpoints: blockEndpoints,
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
|
}
|
|
serverCtx, cancel := context.WithCancel(ctx)
|
|
|
|
// This is set to allow local DERP traffic to be proxied through memory
|
|
// instead of needing to hit the external access URL. Don't use the ctx
|
|
// given in this callback, it's only valid while connecting.
|
|
if derpServer != nil {
|
|
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
|
if !region.EmbeddedRelay {
|
|
return nil
|
|
}
|
|
logger.Debug(ctx, "connecting to embedded DERP via in-memory pipe")
|
|
left, right := net.Pipe()
|
|
go func() {
|
|
defer left.Close()
|
|
defer right.Close()
|
|
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
|
derpServer.Accept(ctx, right, brw, "internal")
|
|
}()
|
|
return left
|
|
})
|
|
}
|
|
|
|
tracer := traceProvider.Tracer(tracing.TracerName)
|
|
|
|
controller := tailnet.NewController(logger, dialer)
|
|
// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
|
|
// there is an embedded relay, we use the local in-memory dialer.
|
|
controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
|
|
coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
|
|
controller.CoordCtrl = coordCtrl
|
|
// TODO: support controller.TelemetryCtrl
|
|
|
|
tn := &ServerTailnet{
|
|
ctx: serverCtx,
|
|
cancel: cancel,
|
|
logger: logger,
|
|
tracer: tracer,
|
|
conn: conn,
|
|
coordinatee: conn,
|
|
controller: controller,
|
|
coordCtrl: coordCtrl,
|
|
transport: tailnetTransport.Clone(),
|
|
connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
|
Namespace: "coder",
|
|
Subsystem: "servertailnet",
|
|
Name: "open_connections",
|
|
Help: "Total number of TCP connections currently open to workspace agents.",
|
|
}, []string{"network"}),
|
|
totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
|
|
Namespace: "coder",
|
|
Subsystem: "servertailnet",
|
|
Name: "connections_total",
|
|
Help: "Total number of TCP connections made to workspace agents.",
|
|
}, []string{"network"}),
|
|
}
|
|
tn.transport.DialContext = tn.dialContext
|
|
// These options are mostly just picked at random, and they can likely be
|
|
// fine-tuned further. Generally, users are running applications in dev mode
|
|
// which can generate hundreds of requests per page load, so we increased
|
|
// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
|
|
// conns.
|
|
tn.transport.MaxIdleConnsPerHost = 6
|
|
tn.transport.MaxIdleConns = 0
|
|
tn.transport.IdleConnTimeout = 10 * time.Minute
|
|
// We intentionally don't verify the certificate chain here.
|
|
// The connection to the workspace is already established and most
|
|
// apps are already going to be accessed over plain HTTP, this config
|
|
// simply allows apps being run over HTTPS to be accessed without error --
|
|
// many of which may be using self-signed certs.
|
|
tn.transport.TLSClientConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
//nolint:gosec
|
|
InsecureSkipVerify: true,
|
|
}
|
|
|
|
tn.controller.Run(tn.ctx)
|
|
return tn, nil
|
|
}
|
|
|
|
// Conn is used to access the underlying tailnet conn of the ServerTailnet. It
|
|
// should only be used for read-only purposes.
|
|
func (s *ServerTailnet) Conn() *tailnet.Conn {
|
|
return s.conn
|
|
}
|
|
|
|
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
|
|
s.connsPerAgent.Describe(descs)
|
|
s.totalConns.Describe(descs)
|
|
}
|
|
|
|
func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
|
|
s.connsPerAgent.Collect(metrics)
|
|
s.totalConns.Collect(metrics)
|
|
}
|
|
|
|
type ServerTailnet struct {
|
|
ctx context.Context
|
|
cancel func()
|
|
|
|
logger slog.Logger
|
|
tracer trace.Tracer
|
|
|
|
// in prod, these are the same, but coordinatee is a subset of Conn's
|
|
// methods which makes some tests easier.
|
|
conn *tailnet.Conn
|
|
coordinatee tailnet.Coordinatee
|
|
|
|
controller *tailnet.Controller
|
|
coordCtrl *MultiAgentController
|
|
|
|
transport *http.Transport
|
|
|
|
connsPerAgent *prometheus.GaugeVec
|
|
totalConns *prometheus.CounterVec
|
|
}
|
|
|
|
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHostname string) *httputil.ReverseProxy {
|
|
// Rewrite the targetURL's Host to point to the agent's IP. This is
|
|
// necessary because due to TCP connection caching, each agent needs to be
|
|
// addressed invidivually. Otherwise, all connections get dialed as
|
|
// "localhost:port", causing connections to be shared across agents.
|
|
tgt := *targetURL
|
|
_, port, _ := net.SplitHostPort(tgt.Host)
|
|
tgt.Host = net.JoinHostPort(tailnet.TailscaleServicePrefix.AddrFromUUID(agentID).String(), port)
|
|
|
|
proxy := httputil.NewSingleHostReverseProxy(&tgt)
|
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, theErr error) {
|
|
var (
|
|
desc = "Failed to proxy request to application: " + theErr.Error()
|
|
additionalInfo = ""
|
|
additionalButtonLink = ""
|
|
additionalButtonText = ""
|
|
)
|
|
|
|
var tlsError tls.RecordHeaderError
|
|
if (errors.As(theErr, &tlsError) && tlsError.Msg == "first record does not look like a TLS handshake") ||
|
|
errors.Is(theErr, http.ErrSchemeMismatch) {
|
|
// If the error is due to an HTTP/HTTPS mismatch, we can provide a
|
|
// more helpful error message with redirect buttons.
|
|
switchURL := url.URL{
|
|
Scheme: dashboardURL.Scheme,
|
|
}
|
|
_, protocol, isPort := app.PortInfo()
|
|
if isPort {
|
|
targetProtocol := "https"
|
|
if protocol == "https" {
|
|
targetProtocol = "http"
|
|
}
|
|
app = app.ChangePortProtocol(targetProtocol)
|
|
|
|
switchURL.Host = fmt.Sprintf("%s%s", app.String(), strings.TrimPrefix(wildcardHostname, "*"))
|
|
additionalButtonLink = switchURL.String()
|
|
additionalButtonText = fmt.Sprintf("Switch to %s", strings.ToUpper(targetProtocol))
|
|
additionalInfo += fmt.Sprintf("This error seems to be due to an app protocol mismatch, try switching to %s.", strings.ToUpper(targetProtocol))
|
|
}
|
|
}
|
|
|
|
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
|
Status: http.StatusBadGateway,
|
|
Title: "Bad Gateway",
|
|
Description: desc,
|
|
RetryEnabled: true,
|
|
DashboardURL: dashboardURL.String(),
|
|
AdditionalInfo: additionalInfo,
|
|
AdditionalButtonLink: additionalButtonLink,
|
|
AdditionalButtonText: additionalButtonText,
|
|
})
|
|
}
|
|
proxy.Director = s.director(agentID, proxy.Director)
|
|
proxy.Transport = s.transport
|
|
|
|
return proxy
|
|
}
|
|
|
|
type agentIDKey struct{}
|
|
|
|
// director makes sure agentIDKey is set on the context in the reverse proxy.
|
|
// This allows the transport to correctly identify which agent to dial to.
|
|
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
|
|
return func(req *http.Request) {
|
|
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
|
|
*req = *req.WithContext(ctx)
|
|
prev(req)
|
|
}
|
|
}
|
|
|
|
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
|
|
if !ok {
|
|
return nil, xerrors.Errorf("no agent id attached")
|
|
}
|
|
|
|
nc, err := s.DialAgentNetConn(ctx, agentID, network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.connsPerAgent.WithLabelValues("tcp").Inc()
|
|
s.totalConns.WithLabelValues("tcp").Inc()
|
|
return &instrumentedConn{
|
|
Conn: nc,
|
|
agentID: agentID,
|
|
connsPerAgent: s.connsPerAgent,
|
|
}, nil
|
|
}
|
|
|
|
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
|
|
var (
|
|
conn *workspacesdk.AgentConn
|
|
ret func()
|
|
)
|
|
|
|
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
|
err := s.coordCtrl.ensureAgent(agentID)
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
|
}
|
|
ret = s.coordCtrl.acquireTicket(agentID)
|
|
|
|
conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
|
|
AgentID: agentID,
|
|
CloseFunc: func() error { return workspacesdk.ErrSkipClose },
|
|
})
|
|
|
|
// Since we now have an open conn, be careful to close it if we error
|
|
// without returning it to the user.
|
|
|
|
reachable := conn.AwaitReachable(ctx)
|
|
if !reachable {
|
|
ret()
|
|
return nil, nil, xerrors.New("agent is unreachable")
|
|
}
|
|
|
|
return conn, ret, nil
|
|
}
|
|
|
|
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
|
|
conn, release, err := s.AgentConn(ctx, agentID)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("acquire agent conn: %w", err)
|
|
}
|
|
|
|
// Since we now have an open conn, be careful to close it if we error
|
|
// without returning it to the user.
|
|
|
|
nc, err := conn.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
release()
|
|
return nil, xerrors.Errorf("dial context: %w", err)
|
|
}
|
|
|
|
return &netConnCloser{Conn: nc, close: func() {
|
|
release()
|
|
}}, err
|
|
}
|
|
|
|
func (s *ServerTailnet) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
|
|
s.conn.MagicsockServeHTTPDebug(w, r)
|
|
}
|
|
|
|
type netConnCloser struct {
|
|
net.Conn
|
|
close func()
|
|
}
|
|
|
|
func (c *netConnCloser) Close() error {
|
|
c.close()
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (s *ServerTailnet) Close() error {
|
|
s.logger.Info(s.ctx, "closing server tailnet")
|
|
defer s.logger.Debug(s.ctx, "server tailnet close complete")
|
|
s.cancel()
|
|
_ = s.conn.Close()
|
|
s.transport.CloseIdleConnections()
|
|
s.coordCtrl.Close()
|
|
<-s.controller.Closed()
|
|
return nil
|
|
}
|
|
|
|
type instrumentedConn struct {
|
|
net.Conn
|
|
|
|
agentID uuid.UUID
|
|
closeOnce sync.Once
|
|
connsPerAgent *prometheus.GaugeVec
|
|
}
|
|
|
|
func (c *instrumentedConn) Close() error {
|
|
c.closeOnce.Do(func() {
|
|
c.connsPerAgent.WithLabelValues("tcp").Dec()
|
|
})
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
|
|
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
|
|
// have no active connections and haven't been used in a while.
|
|
type MultiAgentController struct {
|
|
*tailnet.BasicCoordinationController
|
|
|
|
logger slog.Logger
|
|
tracer trace.Tracer
|
|
|
|
mu sync.Mutex
|
|
// connectionTimes is a map of agents the server wants to keep a connection to. It
|
|
// contains the last time the agent was connected to.
|
|
connectionTimes map[uuid.UUID]time.Time
|
|
// tickets is a map of destinations to a set of connection tickets, representing open
|
|
// connections to the destination
|
|
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
|
coordination *tailnet.BasicCoordination
|
|
|
|
cancel context.CancelFunc
|
|
expireOldAgentsDone chan struct{}
|
|
}
|
|
|
|
func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
|
|
b := m.BasicCoordinationController.NewCoordination(client)
|
|
// resync all destinations
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.coordination = b
|
|
for agentID := range m.connectionTimes {
|
|
err := client.Send(&proto.CoordinateRequest{
|
|
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
|
})
|
|
if err != nil {
|
|
m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
|
|
slog.Error(err))
|
|
b.SendErr(err)
|
|
_ = client.Close()
|
|
m.coordination = nil
|
|
break
|
|
}
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
_, ok := m.connectionTimes[agentID]
|
|
// If we don't have the agent, subscribe.
|
|
if !ok {
|
|
m.logger.Debug(context.Background(),
|
|
"subscribing to agent", slog.F("agent_id", agentID))
|
|
if m.coordination != nil {
|
|
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
|
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
|
})
|
|
if err != nil {
|
|
err = xerrors.Errorf("subscribe agent: %w", err)
|
|
m.coordination.SendErr(err)
|
|
_ = m.coordination.Client.Close()
|
|
m.coordination = nil
|
|
return err
|
|
}
|
|
}
|
|
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
|
}
|
|
m.connectionTimes[agentID] = time.Now()
|
|
return nil
|
|
}
|
|
|
|
func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
|
|
id := uuid.New()
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.tickets[agentID][id] = struct{}{}
|
|
|
|
return func() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
delete(m.tickets[agentID], id)
|
|
}
|
|
}
|
|
|
|
func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
|
|
defer close(m.expireOldAgentsDone)
|
|
defer m.logger.Debug(context.Background(), "stopped expiring old agents")
|
|
const (
|
|
tick = 5 * time.Minute
|
|
cutoff = 30 * time.Minute
|
|
)
|
|
|
|
ticker := time.NewTicker(tick)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
m.doExpireOldAgents(ctx, cutoff)
|
|
}
|
|
}
|
|
|
|
func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
|
|
// TODO: add some attrs to this.
|
|
ctx, span := m.tracer.Start(ctx, tracing.FuncName())
|
|
defer span.End()
|
|
|
|
start := time.Now()
|
|
deletedCount := 0
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
|
|
for agentID, lastConnection := range m.connectionTimes {
|
|
// If no one has connected since the cutoff and there are no active
|
|
// connections, remove the agent.
|
|
if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
|
|
if m.coordination != nil {
|
|
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
|
RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
|
})
|
|
if err != nil {
|
|
m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
|
m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
|
|
// close the client because we do not want to do a graceful disconnect by
|
|
// closing the coordination.
|
|
_ = m.coordination.Client.Close()
|
|
m.coordination = nil
|
|
// Here we continue deleting any inactive agents: there is no point in
|
|
// re-establishing tunnels to expired agents when we eventually reconnect.
|
|
}
|
|
}
|
|
deletedCount++
|
|
delete(m.connectionTimes, agentID)
|
|
}
|
|
}
|
|
m.logger.Debug(ctx, "pruned inactive agents",
|
|
slog.F("deleted", deletedCount),
|
|
slog.F("took", time.Since(start)),
|
|
)
|
|
}
|
|
|
|
func (m *MultiAgentController) Close() {
|
|
m.cancel()
|
|
<-m.expireOldAgentsDone
|
|
}
|
|
|
|
func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
|
|
m := &MultiAgentController{
|
|
BasicCoordinationController: &tailnet.BasicCoordinationController{
|
|
Logger: logger,
|
|
Coordinatee: coordinatee,
|
|
SendAcks: false, // we are a client, connecting to multiple agents
|
|
},
|
|
logger: logger,
|
|
tracer: tracer,
|
|
connectionTimes: make(map[uuid.UUID]time.Time),
|
|
tickets: make(map[uuid.UUID]map[uuid.UUID]struct{}),
|
|
expireOldAgentsDone: make(chan struct{}),
|
|
}
|
|
ctx, m.cancel = context.WithCancel(ctx)
|
|
go m.expireOldAgents(ctx)
|
|
return m
|
|
}
|
|
|
|
// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
|
|
// service running in the same memory space.
|
|
type InmemTailnetDialer struct {
|
|
CoordPtr *atomic.Pointer[tailnet.Coordinator]
|
|
DERPFn func() *tailcfg.DERPMap
|
|
Logger slog.Logger
|
|
ClientID uuid.UUID
|
|
}
|
|
|
|
func (a *InmemTailnetDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
|
coord := a.CoordPtr.Load()
|
|
if coord == nil {
|
|
return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
|
|
}
|
|
coordClient := tailnet.NewInMemoryCoordinatorClient(
|
|
a.Logger, a.ClientID, tailnet.SingleTailnetCoordinateeAuth{}, *coord)
|
|
derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
|
|
return tailnet.ControlProtocolClients{
|
|
Closer: closeAll{coord: coordClient, derp: derpClient},
|
|
Coordinator: coordClient,
|
|
DERP: derpClient,
|
|
}, nil
|
|
}
|
|
|
|
func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
a := &pollingDERPClient{
|
|
fn: derpFn,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
logger: logger,
|
|
ch: make(chan *tailcfg.DERPMap),
|
|
loopDone: make(chan struct{}),
|
|
}
|
|
go a.pollDERP()
|
|
return a
|
|
}
|
|
|
|
// pollingDERPClient is a DERP client that just calls a function on a polling
|
|
// interval
|
|
type pollingDERPClient struct {
|
|
fn func() *tailcfg.DERPMap
|
|
logger slog.Logger
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
loopDone chan struct{}
|
|
lastDERPMap *tailcfg.DERPMap
|
|
ch chan *tailcfg.DERPMap
|
|
}
|
|
|
|
// Close the DERP client
|
|
func (a *pollingDERPClient) Close() error {
|
|
a.cancel()
|
|
<-a.loopDone
|
|
return nil
|
|
}
|
|
|
|
func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
|
|
select {
|
|
case <-a.ctx.Done():
|
|
return nil, a.ctx.Err()
|
|
case dm := <-a.ch:
|
|
return dm, nil
|
|
}
|
|
}
|
|
|
|
func (a *pollingDERPClient) pollDERP() {
|
|
defer close(a.loopDone)
|
|
defer a.logger.Debug(a.ctx, "polling DERPMap exited")
|
|
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-a.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
newDerpMap := a.fn()
|
|
if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
|
|
select {
|
|
case <-a.ctx.Done():
|
|
return
|
|
case a.ch <- newDerpMap:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type closeAll struct {
|
|
coord tailnet.CoordinatorClient
|
|
derp tailnet.DERPClient
|
|
}
|
|
|
|
func (c closeAll) Close() error {
|
|
cErr := c.coord.Close()
|
|
dErr := c.derp.Close()
|
|
if cErr != nil {
|
|
return cErr
|
|
}
|
|
return dErr
|
|
}
|