Files
coder/coderd/tailnet.go
Spike Curtis 04991f425a fix: set node callback each time we reinit the coordinator in servertailnet (#12140)
I think this will resolve #12136, but let's get a proper test at the system level before closing.

Before this change, we only register the node callback at start of day for the server tailnet.  If the coordinator changes, like we know happens when we are licensed for the PGCoordinator, we close the connection to the old coord, and open a new one to the new coord.

The callback is designed to direct the updates to the new coordinator, but there is nothing that specifically triggers it to fire after we connect to the new coordinator.

If we have STUN, then periodic re-STUNs will generally get it to fire eventually, but without STUN we could go indefinitely without a callback.

This PR changes the servertailnet to re-register the callback each time we reconnect to the coordinator.  Registering a callback (even if it's the same callback) triggers an immediate call with our node information, so the new coordinator will have it.
2024-02-14 20:45:31 +04:00

505 lines
14 KiB
Go

package coderd
import (
"bufio"
"context"
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/retry"
)
// tailnetTransport is http.DefaultTransport asserted to its concrete type. It
// is captured once at package initialisation and cloned by NewServerTailnet.
var tailnetTransport *http.Transport

// init captures http.DefaultTransport as an *http.Transport and panics if the
// default transport is some other http.RoundTripper implementation.
func init() {
	transport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		panic("dev error: default transport is the wrong type")
	}
	tailnetTransport = transport
}
// Compile-time assertion that *ServerTailnet satisfies workspaceapps.AgentProvider.
var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)
// NewServerTailnet creates a new tailnet intended for use by coderd.
//
// It builds a tailnet.Conn from the DERP map returned by derpMapFn, obtains
// the initial MultiAgentConn from getMultiAgent, registers the node callback,
// and starts background goroutines that re-poll derpMapFn every five seconds,
// apply peer updates from the coordinator, and prune inactive agents. When
// derpServer is non-nil, DERP regions flagged EmbeddedRelay are dialed
// in-memory through it.
//
// The background work is bound to a context derived from ctx and is stopped
// by ServerTailnet.Close. If an error is returned, anything started by this
// call has already been torn down.
func NewServerTailnet(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	derpMapFn func() *tailcfg.DERPMap,
	derpForceWebSockets bool,
	getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
	logger = logger.Named("servertailnet")
	originalDerpMap := derpMapFn()
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:             originalDerpMap,
		DERPForceWebSockets: derpForceWebSockets,
		Logger:              logger,
	})
	if err != nil {
		return nil, xerrors.Errorf("create tailnet conn: %w", err)
	}

	serverCtx, cancel := context.WithCancel(ctx)

	// Poll derpMapFn and push the map into the conn whenever it changes.
	// derpMapUpdaterClosed is closed when this goroutine exits so that
	// teardown can wait for it.
	derpMapUpdaterClosed := make(chan struct{})
	go func() {
		defer close(derpMapUpdaterClosed)

		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-serverCtx.Done():
				return
			case <-ticker.C:
			}

			newDerpMap := derpMapFn()
			if !tailnet.CompareDERPMaps(originalDerpMap, newDerpMap) {
				conn.SetDERPMap(newDerpMap)
				originalDerpMap = newDerpMap
			}
		}
	}()

	tn := &ServerTailnet{
		ctx:                  serverCtx,
		cancel:               cancel,
		derpMapUpdaterClosed: derpMapUpdaterClosed,
		logger:               logger,
		tracer:               traceProvider.Tracer(tracing.TracerName),
		conn:                 conn,
		coordinatee:          conn,
		getMultiAgent:        getMultiAgent,
		agentConnectionTimes: map[uuid.UUID]time.Time{},
		agentTickets:         map[uuid.UUID]map[uuid.UUID]struct{}{},
		transport:            tailnetTransport.Clone(),
		connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "open_connections",
			Help:      "Total number of TCP connections currently open to workspace agents.",
		}, []string{"network"}),
		totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "connections_total",
			Help:      "Total number of TCP connections made to workspace agents.",
		}, []string{"network"}),
	}
	tn.transport.DialContext = tn.dialContext
	// These options are mostly just picked at random, and they can likely be
	// fine tuned further. Generally, users are running applications in dev mode
	// which can generate hundreds of requests per page load, so we increased
	// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
	// conns.
	tn.transport.MaxIdleConnsPerHost = 6
	tn.transport.MaxIdleConns = 0
	tn.transport.IdleConnTimeout = 10 * time.Minute
	// We intentionally don't verify the certificate chain here.
	// The connection to the workspace is already established and most
	// apps are already going to be accessed over plain HTTP, this config
	// simply allows apps being run over HTTPS to be accessed without error --
	// many of which may be using self-signed certs.
	tn.transport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
		//nolint:gosec
		InsecureSkipVerify: true,
	}

	agentConn, err := getMultiAgent(ctx)
	if err != nil {
		// Tear down what was started above. Without this the DERP map updater
		// goroutine keeps running until the parent ctx ends and the tailnet
		// conn is never closed. The order mirrors ServerTailnet.Close.
		cancel()
		_ = conn.Close()
		<-derpMapUpdaterClosed
		return nil, xerrors.Errorf("get initial multi agent: %w", err)
	}
	tn.agentConn.Store(&agentConn)
	// registering the callback also triggers send of the initial node
	tn.coordinatee.SetNodeCallback(tn.nodeCallback)

	// This is set to allow local DERP traffic to be proxied through memory
	// instead of needing to hit the external access URL. Don't use the ctx
	// given in this callback, it's only valid while connecting.
	if derpServer != nil {
		conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
			if !region.EmbeddedRelay {
				return nil
			}
			left, right := net.Pipe()
			go func() {
				defer left.Close()
				defer right.Close()
				brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
				derpServer.Accept(ctx, right, brw, "internal")
			}()
			return left
		})
	}

	go tn.watchAgentUpdates()
	go tn.expireOldAgents()
	return tn, nil
}
// nodeCallback converts the given tailnet node to its proto form and sends it
// over the current multi-agent connection via UpdateSelf. Conversion and send
// failures are logged rather than returned.
func (s *ServerTailnet) nodeCallback(node *tailnet.Node) {
	logCtx := context.Background()

	protoNode, err := tailnet.NodeToProto(node)
	if err != nil {
		s.logger.Critical(logCtx, "failed to convert node", slog.Error(err))
		return
	}

	if err := s.getAgentConn().UpdateSelf(protoNode); err != nil {
		s.logger.Warn(logCtx, "broadcast server node to agents", slog.Error(err))
	}
}
// Describe sends the descriptors of the open-connections gauge and the
// total-connections counter to descs, in that order.
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
	for _, collector := range []prometheus.Collector{s.connsPerAgent, s.totalConns} {
		collector.Describe(descs)
	}
}
// Collect sends the current values of the open-connections gauge and the
// total-connections counter to metrics, in that order.
func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
	for _, collector := range []prometheus.Collector{s.connsPerAgent, s.totalConns} {
		collector.Collect(metrics)
	}
}
// expireOldAgents runs until s.ctx is done, pruning agents every five minutes
// using a thirty-minute inactivity cutoff.
func (s *ServerTailnet) expireOldAgents() {
	const (
		pruneInterval = 5 * time.Minute
		idleCutoff    = 30 * time.Minute
	)

	ticker := time.NewTicker(pruneInterval)
	defer ticker.Stop()

	done := s.ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			s.doExpireOldAgents(idleCutoff)
		}
	}
}
// doExpireOldAgents unsubscribes from, and stops tracking, every agent whose
// last recorded connection is older than cutoff and that has no outstanding
// tickets. An agent whose unsubscribe fails is kept and retried on a later
// pass. The whole scan runs with nodesMu held.
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
	defer span.End()

	began := time.Now()
	pruned := 0

	s.nodesMu.Lock()
	s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
	multiAgent := s.getAgentConn()
	for id, lastSeen := range s.agentConnectionTimes {
		// Recently used agents stay.
		if time.Since(lastSeen) <= cutoff {
			continue
		}
		// So do agents that still have open connections.
		if len(s.agentTickets[id]) != 0 {
			continue
		}
		if err := multiAgent.UnsubscribeAgent(id); err != nil {
			s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", id))
			continue
		}
		pruned++
		// NOTE(review): the (empty) agentTickets entry is left in place here;
		// ensureAgent replaces it if the agent is subscribed again.
		delete(s.agentConnectionTimes, id)
	}
	s.nodesMu.Unlock()

	s.logger.Debug(s.ctx, "successfully pruned inactive agents",
		slog.F("deleted", pruned),
		slog.F("took", time.Since(began)),
	)
}
// watchAgentUpdates pumps peer updates from the current multi-agent
// connection into the coordinatee. If the connection turns out to be closed
// while s.ctx is still live, all peers are marked lost and the coordinator
// connection is re-established; in every other failure case the loop exits.
func (s *ServerTailnet) watchAgentUpdates() {
	for {
		multiAgent := s.getAgentConn()
		resp, ok := multiAgent.NextUpdate(s.ctx)
		if !ok {
			// Only a closed connection under a live context is recoverable.
			if !multiAgent.IsClosed() || s.ctx.Err() != nil {
				return
			}
			s.logger.Warn(s.ctx, "multiagent closed, reinitializing")
			s.coordinatee.SetAllPeersLost()
			s.reinitCoordinator()
			continue
		}

		err := s.coordinatee.UpdatePeers(resp.GetPeerUpdates())
		switch {
		case err == nil:
			continue
		case xerrors.Is(err, tailnet.ErrConnClosed):
			s.logger.Warn(context.Background(), "tailnet conn closed, exiting watchAgentUpdates", slog.Error(err))
		default:
			s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
		}
		// Any UpdatePeers error ends the watcher.
		return
	}
}
// getAgentConn returns the multi-agent connection currently stored in
// s.agentConn. It dereferences the stored pointer, so a connection must have
// been stored before the first call.
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
	current := s.agentConn.Load()
	return *current
}
// reinitCoordinator obtains a fresh multi-agent connection, retrying with
// backoff (25ms up to 5s) for as long as the retrier's Wait on s.ctx keeps
// returning true. On success it stores the new connection, re-registers the
// node callback, and resubscribes to every tracked agent, all while holding
// nodesMu.
func (s *ServerTailnet) reinitCoordinator() {
	began := time.Now()
	backoff := retry.New(25*time.Millisecond, 5*time.Second)

	for backoff.Wait(s.ctx) {
		s.nodesMu.Lock()
		multiAgent, err := s.getMultiAgent(s.ctx)
		if err != nil {
			s.nodesMu.Unlock()
			s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err))
			continue
		}
		s.agentConn.Store(&multiAgent)

		// reset the Node callback, which triggers the conn to send the node
		// immediately, and also register for updates
		s.coordinatee.SetNodeCallback(s.nodeCallback)

		// Resubscribe to all of the agents we're tracking. A failed
		// subscription is logged and does not abort the rest.
		for id := range s.agentConnectionTimes {
			if subErr := multiAgent.SubscribeAgent(id); subErr != nil {
				s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(subErr), slog.F("agent_id", id))
			}
		}

		s.logger.Info(s.ctx, "successfully reinitialized multiagent",
			slog.F("agents", len(s.agentConnectionTimes)),
			slog.F("took", time.Since(began)),
		)
		s.nodesMu.Unlock()
		return
	}
}
// ServerTailnet is the tailnet node used by coderd itself. It tracks the
// agents it has subscribed to through a MultiAgentConn and dials them over
// conn.
type ServerTailnet struct {
// ctx is cancelled by cancel (see Close) and bounds the background
// goroutines started in NewServerTailnet.
ctx context.Context
cancel func()
// derpMapUpdaterClosed is closed when the DERP map updater goroutine exits.
derpMapUpdaterClosed chan struct{}
logger slog.Logger
tracer trace.Tracer
// in prod, these are the same, but coordinatee is a subset of Conn's
// methods which makes some tests easier.
conn *tailnet.Conn
coordinatee tailnet.Coordinatee
// getMultiAgent opens a connection to the coordinator. It is called once at
// construction and again by reinitCoordinator.
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
// agentConn holds the current coordinator connection; read it through
// getAgentConn.
agentConn atomic.Pointer[tailnet.MultiAgentConn]
// nodesMu guards agentConnectionTimes and agentTickets.
nodesMu sync.Mutex
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
// to.
agentConnectionTimes map[uuid.UUID]time.Time
// agentTickets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
// transport is the HTTP transport handed to reverse proxies; its
// DialContext is set to dialContext.
transport *http.Transport
// connsPerAgent is incremented on each successful dial and decremented when
// the dialed connection is closed.
connsPerAgent *prometheus.GaugeVec
// totalConns is incremented on each successful dial.
totalConns *prometheus.CounterVec
}
// ReverseProxy returns a reverse proxy that forwards requests for targetURL to
// the given agent over the server tailnet. Proxy failures render a static
// "Bad Gateway" page that links back to dashboardURL.
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
	// Rewrite the target's Host to point to the agent's IP. This is necessary
	// because, due to TCP connection caching, each agent needs to be addressed
	// individually. Otherwise, all connections get dialed as "localhost:port",
	// causing connections to be shared across agents.
	agentURL := *targetURL
	_, port, _ := net.SplitHostPort(agentURL.Host)
	agentIP := tailnet.IPFromUUID(agentID).String()
	agentURL.Host = net.JoinHostPort(agentIP, port)

	renderBadGateway := func(w http.ResponseWriter, r *http.Request, err error) {
		site.RenderStaticErrorPage(w, r, site.ErrorPageData{
			Status:       http.StatusBadGateway,
			Title:        "Bad Gateway",
			Description:  "Failed to proxy request to application: " + err.Error(),
			RetryEnabled: true,
			DashboardURL: dashboardURL.String(),
		})
	}

	proxy := httputil.NewSingleHostReverseProxy(&agentURL)
	proxy.ErrorHandler = renderBadGateway
	// Wrap the stock director so every request carries the agent ID.
	proxy.Director = s.director(agentID, proxy.Director)
	proxy.Transport = s.transport
	return proxy
}
// agentIDKey is the context key under which director stores the target agent's
// ID and from which dialContext reads it back.
type agentIDKey struct{}
// director wraps prev so that each proxied request has agentID stored on its
// context under agentIDKey before prev runs. The transport's dialer reads that
// value to work out which agent to dial.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
	return func(req *http.Request) {
		tagged := req.WithContext(context.WithValue(req.Context(), agentIDKey{}, agentID))
		// Overwrite in place: the caller keeps using the same *http.Request.
		*req = *tagged
		prev(req)
	}
}
// dialContext dials addr on the agent whose ID is stored on ctx under
// agentIDKey, and fails if no ID is attached. A successful dial bumps both
// connection metrics and returns a wrapper that decrements the open-connection
// gauge when closed. The metrics are always labelled "tcp", whatever network
// is requested.
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	agentID, found := ctx.Value(agentIDKey{}).(uuid.UUID)
	if !found {
		return nil, xerrors.Errorf("no agent id attached")
	}

	agentNetConn, err := s.DialAgentNetConn(ctx, agentID, network, addr)
	if err != nil {
		return nil, err
	}

	const networkLabel = "tcp"
	s.connsPerAgent.WithLabelValues(networkLabel).Inc()
	s.totalConns.WithLabelValues(networkLabel).Inc()

	wrapped := &instrumentedConn{
		Conn:          agentNetConn,
		agentID:       agentID,
		connsPerAgent: s.connsPerAgent,
	}
	return wrapped, nil
}
// ensureAgent makes sure the server is subscribed to agentID and records the
// current time as the agent's last connection time. An agent not yet tracked
// is subscribed on the current multi-agent connection and given a fresh, empty
// ticket set. If subscribing fails, nothing is recorded.
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
	s.nodesMu.Lock()
	defer s.nodesMu.Unlock()

	if _, tracked := s.agentConnectionTimes[agentID]; !tracked {
		s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
		if err := s.getAgentConn().SubscribeAgent(agentID); err != nil {
			return xerrors.Errorf("subscribe agent: %w", err)
		}
		s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
	}

	s.agentConnectionTimes[agentID] = time.Now()
	return nil
}
// acquireTicket registers a new, uniquely identified ticket for agentID and
// returns a function that removes it again. An agent with outstanding tickets
// is skipped by doExpireOldAgents. The agent's ticket set must already exist
// (ensureAgent creates it); writing to a missing set panics.
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
	ticketID := uuid.New()

	s.nodesMu.Lock()
	s.agentTickets[agentID][ticketID] = struct{}{}
	s.nodesMu.Unlock()

	release = func() {
		s.nodesMu.Lock()
		defer s.nodesMu.Unlock()
		delete(s.agentTickets[agentID], ticketID)
	}
	return release
}
// AgentConn returns a connection to the given agent together with a release
// function that drops the ticket taken out for this caller. It fails if the
// agent cannot be subscribed to, or if it does not become reachable
// (AwaitReachable on ctx reports false); in the latter case the ticket is
// released before returning.
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
	s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
	if err := s.ensureAgent(agentID); err != nil {
		return nil, nil, xerrors.Errorf("ensure agent: %w", err)
	}

	release := s.acquireTicket(agentID)
	agentConn := codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
		// The underlying tailnet conn is shared, so closing the agent conn
		// must not close it.
		CloseFunc: func() error { return codersdk.ErrSkipClose },
	})

	// A ticket is now held; hand it back if the conn is not returned to the
	// caller.
	if agentConn.AwaitReachable(ctx) {
		return agentConn, release, nil
	}
	release()
	return nil, nil, xerrors.New("agent is unreachable")
}
// DialAgentNetConn dials network/addr on the given agent and returns the
// resulting net.Conn. The agent ticket taken for the dial is released when the
// returned connection is closed, or immediately if the dial fails.
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
	agentConn, release, err := s.AgentConn(ctx, agentID)
	if err != nil {
		return nil, xerrors.Errorf("acquire agent conn: %w", err)
	}

	// From here on a ticket is held, so every exit path must either release
	// it or hand that responsibility to the returned conn.
	dialed, err := agentConn.DialContext(ctx, network, addr)
	if err != nil {
		release()
		return nil, xerrors.Errorf("dial context: %w", err)
	}

	return &netConnCloser{Conn: dialed, close: release}, nil
}
// ServeHTTPDebug hands the request to the underlying tailnet conn's
// MagicsockServeHTTPDebug handler.
func (s *ServerTailnet) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
	s.conn.MagicsockServeHTTPDebug(w, r)
}
// netConnCloser wraps a net.Conn so that an extra hook runs every time the
// connection is closed.
type netConnCloser struct {
	net.Conn
	// close is invoked by Close before the wrapped connection is closed.
	close func()
}

// Close runs the close hook and then closes the wrapped connection, returning
// the wrapped connection's error.
func (c *netConnCloser) Close() error {
	c.close()
	err := c.Conn.Close()
	return err
}
// Close shuts the server tailnet down: it cancels the server context, closes
// the tailnet conn, closes the transport's idle HTTP connections, and then
// waits for the DERP map updater goroutine to exit. It always returns nil.
func (s *ServerTailnet) Close() error {
s.cancel()
// The error from closing the conn is deliberately discarded.
_ = s.conn.Close()
s.transport.CloseIdleConnections()
// Block until the DERP map updater has observed the cancellation above.
<-s.derpMapUpdaterClosed
return nil
}
// instrumentedConn wraps a net.Conn so that the open-connections gauge is
// decremented exactly once when the connection is closed.
type instrumentedConn struct {
	net.Conn

	agentID uuid.UUID
	// closeOnce keeps repeated Close calls from decrementing the gauge more
	// than once.
	closeOnce     sync.Once
	connsPerAgent *prometheus.GaugeVec
}

// Close decrements the "tcp" open-connections gauge on the first call only,
// and closes the wrapped connection on every call, returning its error.
func (c *instrumentedConn) Close() error {
	c.closeOnce.Do(func() {
		gauge := c.connsPerAgent.WithLabelValues("tcp")
		gauge.Dec()
	})
	err := c.Conn.Close()
	return err
}