package coderd

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"net/netip"
	"net/url"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/xerrors"
	"tailscale.com/derp"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/coderd/tracing"
	"github.com/coder/coder/v2/coderd/workspaceapps"
	"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/site"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)

var tailnetTransport *http.Transport

func init() {
	tp, valid := http.DefaultTransport.(*http.Transport)
	if !valid {
		panic("dev error: default transport is the wrong type")
	}
	tailnetTransport = tp.Clone()
	// We do not want to respect the proxy settings from the environment, since
	// all network traffic happens over wireguard.
	tailnetTransport.Proxy = nil
}

var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)

// NewServerTailnet creates a new tailnet intended for use by coderd.
func NewServerTailnet(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	dialer tailnet.ControlProtocolDialer,
	derpForceWebSockets bool,
	blockEndpoints bool,
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
	logger = logger.Named("servertailnet")
	conn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
		DERPForceWebSockets: derpForceWebSockets,
		Logger:              logger,
		BlockEndpoints:      blockEndpoints,
	})
	if err != nil {
		return nil, xerrors.Errorf("create tailnet conn: %w", err)
	}
	serverCtx, cancel := context.WithCancel(ctx)

	// This is set to allow local DERP traffic to be proxied through memory
	// instead of needing to hit the external access URL. Don't use the ctx
	// given in this callback, it's only valid while connecting.
	if derpServer != nil {
		conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
			// Don't set up the embedded relay if we're shutting down
			if !region.EmbeddedRelay || ctx.Err() != nil {
				return nil
			}
			logger.Debug(ctx, "connecting to embedded DERP via in-memory pipe")
			left, right := net.Pipe()
			go func() {
				defer left.Close()
				defer right.Close()
				brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
				derpServer.Accept(ctx, right, brw, "internal")
			}()
			return left
		})
	}
	tracer := traceProvider.Tracer(tracing.TracerName)
	controller := tailnet.NewController(logger, dialer)
	// it's important to set the DERPRegionDialer above _before_ we set the DERP map so that if
	// there is an embedded relay, we use the local in-memory dialer.
	controller.DERPCtrl = tailnet.NewBasicDERPController(logger, conn)
	coordCtrl := NewMultiAgentController(serverCtx, logger, tracer, conn)
	controller.CoordCtrl = coordCtrl
	// TODO: support controller.TelemetryCtrl

	tn := &ServerTailnet{
		ctx:         serverCtx,
		cancel:      cancel,
		logger:      logger,
		tracer:      tracer,
		conn:        conn,
		coordinatee: conn,
		controller:  controller,
		coordCtrl:   coordCtrl,
		transport:   tailnetTransport.Clone(),
		connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "open_connections",
			Help:      "Total number of TCP connections currently open to workspace agents.",
		}, []string{"network"}),
		totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "servertailnet",
			Name:      "connections_total",
			Help:      "Total number of TCP connections made to workspace agents.",
		}, []string{"network"}),
	}
	tn.transport.DialContext = tn.dialContext

	// These options are mostly just picked at random, and they can likely be
	// fine-tuned further. Generally, users are running applications in dev mode
	// which can generate hundreds of requests per page load, so we increased
	// MaxIdleConnsPerHost from 2 to 6 and removed the limit of total idle
	// conns.
	tn.transport.MaxIdleConnsPerHost = 6
	tn.transport.MaxIdleConns = 0
	tn.transport.IdleConnTimeout = 10 * time.Minute
	// We intentionally don't verify the certificate chain here.
	// The connection to the workspace is already established and most
	// apps are already going to be accessed over plain HTTP, this config
	// simply allows apps being run over HTTPS to be accessed without error --
	// many of which may be using self-signed certs.
	tn.transport.TLSClientConfig = &tls.Config{
		MinVersion: tls.VersionTLS12,
		//nolint:gosec
		InsecureSkipVerify: true,
	}

	tn.controller.Run(tn.ctx)
	return tn, nil
}

// Conn is used to access the underlying tailnet conn of the ServerTailnet. It
// should only be used for read-only purposes.
func (s *ServerTailnet) Conn() *tailnet.Conn {
	return s.conn
}

func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
	s.connsPerAgent.Describe(descs)
	s.totalConns.Describe(descs)
}

func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
	s.connsPerAgent.Collect(metrics)
	s.totalConns.Collect(metrics)
}

type ServerTailnet struct {
	ctx    context.Context
	cancel func()

	logger slog.Logger
	tracer trace.Tracer

	// in prod, these are the same, but coordinatee is a subset of Conn's
	// methods which makes some tests easier.
	conn        *tailnet.Conn
	coordinatee tailnet.Coordinatee
	controller  *tailnet.Controller
	coordCtrl   *MultiAgentController

	transport *http.Transport

	connsPerAgent *prometheus.GaugeVec
	totalConns    *prometheus.CounterVec
}
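
// registerTailnetMetricsExample is an illustrative sketch, not part of the
// original file: because ServerTailnet implements Describe and Collect, it
// satisfies prometheus.Collector and can be registered with a registry
// supplied by the caller.
func registerTailnetMetricsExample(reg *prometheus.Registry, tn *ServerTailnet) error {
	// Registering exposes coder_servertailnet_open_connections and
	// coder_servertailnet_connections_total on the registry.
	return reg.Register(tn)
}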
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHostname string) *httputil.ReverseProxy {
	// Rewrite the targetURL's Host to point to the agent's IP. This is
	// necessary because due to TCP connection caching, each agent needs to be
	// addressed individually. Otherwise, all connections get dialed as
	// "localhost:port", causing connections to be shared across agents.
	tgt := *targetURL
	_, port, _ := net.SplitHostPort(tgt.Host)
	tgt.Host = net.JoinHostPort(tailnet.TailscaleServicePrefix.AddrFromUUID(agentID).String(), port)

	proxy := httputil.NewSingleHostReverseProxy(&tgt)
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, theErr error) {
		var (
			desc                 = "Failed to proxy request to application: " + theErr.Error()
			additionalInfo       = ""
			additionalButtonLink = ""
			additionalButtonText = ""
		)

		var tlsError tls.RecordHeaderError
		if (errors.As(theErr, &tlsError) && tlsError.Msg == "first record does not look like a TLS handshake") ||
			errors.Is(theErr, http.ErrSchemeMismatch) {
			// If the error is due to an HTTP/HTTPS mismatch, we can provide a
			// more helpful error message with redirect buttons.
			switchURL := url.URL{
				Scheme: dashboardURL.Scheme,
			}
			_, protocol, isPort := app.PortInfo()
			if isPort {
				targetProtocol := "https"
				if protocol == "https" {
					targetProtocol = "http"
				}
				app = app.ChangePortProtocol(targetProtocol)

				switchURL.Host = fmt.Sprintf("%s%s", app.String(), strings.TrimPrefix(wildcardHostname, "*"))
				additionalButtonLink = switchURL.String()
				additionalButtonText = fmt.Sprintf("Switch to %s", strings.ToUpper(targetProtocol))
				additionalInfo += fmt.Sprintf("This error seems to be due to an app protocol mismatch, try switching to %s.", strings.ToUpper(targetProtocol))
			}
		}

		site.RenderStaticErrorPage(w, r, site.ErrorPageData{
			Status:               http.StatusBadGateway,
			Title:                "Bad Gateway",
			Description:          desc,
			RetryEnabled:         true,
			DashboardURL:         dashboardURL.String(),
			AdditionalInfo:       additionalInfo,
			AdditionalButtonLink: additionalButtonLink,
			AdditionalButtonText: additionalButtonText,
		})
	}
	proxy.Director = s.director(agentID, proxy.Director)
	proxy.Transport = s.transport

	return proxy
}

type agentIDKey struct{}

// director makes sure agentIDKey is set on the context in the reverse proxy.
// This allows the transport to correctly identify which agent to dial to.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
	return func(req *http.Request) {
		ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
		*req = *req.WithContext(ctx)
		prev(req)
	}
}

func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
	if !ok {
		return nil, xerrors.Errorf("no agent id attached")
	}

	nc, err := s.DialAgentNetConn(ctx, agentID, network, addr)
	if err != nil {
		return nil, err
	}

	s.connsPerAgent.WithLabelValues("tcp").Inc()
	s.totalConns.WithLabelValues("tcp").Inc()
	return &instrumentedConn{
		Conn:          nc,
		agentID:       agentID,
		connsPerAgent: s.connsPerAgent,
	}, nil
}
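
// proxyAppRequestExample is an illustrative sketch, not part of the original
// file: it shows how the pieces above fit together. ReverseProxy stores the
// agent ID in the request context via director, and the shared transport's
// dialContext reads it back to dial the correct agent over the tailnet. All
// parameters are assumed to be supplied by the workspace apps handler.
func proxyAppRequestExample(
	s *ServerTailnet,
	w http.ResponseWriter,
	r *http.Request,
	targetURL, dashboardURL *url.URL,
	agentID uuid.UUID,
	app appurl.ApplicationURL,
	wildcardHostname string,
) {
	proxy := s.ReverseProxy(targetURL, dashboardURL, agentID, app, wildcardHostname)
	proxy.ServeHTTP(w, r)
}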
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
	var (
		conn *workspacesdk.AgentConn
		ret  func()
	)

	s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
	err := s.coordCtrl.ensureAgent(agentID)
	if err != nil {
		return nil, nil, xerrors.Errorf("ensure agent: %w", err)
	}
	ret = s.coordCtrl.acquireTicket(agentID)

	conn = workspacesdk.NewAgentConn(s.conn, workspacesdk.AgentConnOptions{
		AgentID:   agentID,
		CloseFunc: func() error { return workspacesdk.ErrSkipClose },
	})

	// Since we now have an open conn, be careful to close it if we error
	// without returning it to the user.
	reachable := conn.AwaitReachable(ctx)
	if !reachable {
		ret()
		return nil, nil, xerrors.New("agent is unreachable")
	}

	return conn, ret, nil
}

func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
	conn, release, err := s.AgentConn(ctx, agentID)
	if err != nil {
		return nil, xerrors.Errorf("acquire agent conn: %w", err)
	}

	// Since we now have an open conn, be careful to close it if we error
	// without returning it to the user.
	nc, err := conn.DialContext(ctx, network, addr)
	if err != nil {
		release()
		return nil, xerrors.Errorf("dial context: %w", err)
	}

	return &netConnCloser{Conn: nc, close: func() {
		release()
	}}, err
}

func (s *ServerTailnet) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
	s.conn.MagicsockServeHTTPDebug(w, r)
}

type netConnCloser struct {
	net.Conn
	close func()
}

func (c *netConnCloser) Close() error {
	c.close()
	return c.Conn.Close()
}

func (s *ServerTailnet) Close() error {
	s.logger.Info(s.ctx, "closing server tailnet")
	defer s.logger.Debug(s.ctx, "server tailnet close complete")
	s.cancel()
	_ = s.conn.Close()
	s.transport.CloseIdleConnections()
	s.coordCtrl.Close()
	<-s.controller.Closed()
	return nil
}

type instrumentedConn struct {
	net.Conn

	agentID       uuid.UUID
	closeOnce     sync.Once
	connsPerAgent *prometheus.GaugeVec
}

func (c *instrumentedConn) Close() error {
	c.closeOnce.Do(func() {
		c.connsPerAgent.WithLabelValues("tcp").Dec()
	})
	return c.Conn.Close()
}
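
// dialAgentExample is an illustrative sketch, not part of the original file:
// it acquires an agent connection, dials a hypothetical service inside the
// workspace, and releases the connection ticket when finished so the agent
// can eventually be expired by the MultiAgentController below.
func dialAgentExample(ctx context.Context, s *ServerTailnet, agentID uuid.UUID) error {
	conn, release, err := s.AgentConn(ctx, agentID)
	if err != nil {
		return xerrors.Errorf("acquire agent conn: %w", err)
	}
	defer release()

	// "localhost:8080" is a placeholder address; any listener reachable from
	// the agent would work.
	nc, err := conn.DialContext(ctx, "tcp", "localhost:8080")
	if err != nil {
		return xerrors.Errorf("dial workspace service: %w", err)
	}
	return nc.Close()
}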
// MultiAgentController is a tailnet.CoordinationController for connecting to multiple workspace
// agents. It keeps track of connection times to the agents, and removes them on a timer if they
// have no active connections and haven't been used in a while.
type MultiAgentController struct {
	*tailnet.BasicCoordinationController

	logger slog.Logger
	tracer trace.Tracer

	mu sync.Mutex
	// connectionTimes is a map of agents the server wants to keep a connection to. It
	// contains the last time the agent was connected to.
	connectionTimes map[uuid.UUID]time.Time
	// tickets is a map of destinations to a set of connection tickets, representing open
	// connections to the destination
	tickets      map[uuid.UUID]map[uuid.UUID]struct{}
	coordination *tailnet.BasicCoordination

	cancel              context.CancelFunc
	expireOldAgentsDone chan struct{}
}

func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.CloserWaiter {
	b := m.BasicCoordinationController.NewCoordination(client)
	// resync all destinations
	m.mu.Lock()
	defer m.mu.Unlock()
	m.coordination = b
	for agentID := range m.connectionTimes {
		err := client.Send(&proto.CoordinateRequest{
			AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
		})
		if err != nil {
			m.logger.Error(context.Background(), "failed to re-add tunnel", slog.F("agent_id", agentID),
				slog.Error(err))
			b.SendErr(err)
			_ = client.Close()
			m.coordination = nil
			break
		}
	}
	return b
}

func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, ok := m.connectionTimes[agentID]
	// If we don't have the agent, subscribe.
	if !ok {
		m.logger.Debug(context.Background(), "subscribing to agent", slog.F("agent_id", agentID))
		if m.coordination != nil {
			err := m.coordination.Client.Send(&proto.CoordinateRequest{
				AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
			})
			if err != nil {
				err = xerrors.Errorf("subscribe agent: %w", err)
				m.coordination.SendErr(err)
				_ = m.coordination.Client.Close()
				m.coordination = nil
				return err
			}
		}
		m.tickets[agentID] = map[uuid.UUID]struct{}{}
	}

	m.connectionTimes[agentID] = time.Now()
	return nil
}

func (m *MultiAgentController) acquireTicket(agentID uuid.UUID) (release func()) {
	id := uuid.New()
	m.mu.Lock()
	defer m.mu.Unlock()
	m.tickets[agentID][id] = struct{}{}

	return func() {
		m.mu.Lock()
		defer m.mu.Unlock()
		delete(m.tickets[agentID], id)
	}
}

func (m *MultiAgentController) expireOldAgents(ctx context.Context) {
	defer close(m.expireOldAgentsDone)
	defer m.logger.Debug(context.Background(), "stopped expiring old agents")
	const (
		tick   = 5 * time.Minute
		cutoff = 30 * time.Minute
	)

	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		m.doExpireOldAgents(ctx, cutoff)
	}
}

func (m *MultiAgentController) doExpireOldAgents(ctx context.Context, cutoff time.Duration) {
	// TODO: add some attrs to this.
	ctx, span := m.tracer.Start(ctx, tracing.FuncName())
	defer span.End()

	start := time.Now()
	deletedCount := 0

	m.mu.Lock()
	defer m.mu.Unlock()
	m.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(m.connectionTimes)))
	for agentID, lastConnection := range m.connectionTimes {
		// If no one has connected since the cutoff and there are no active
		// connections, remove the agent.
		if time.Since(lastConnection) > cutoff && len(m.tickets[agentID]) == 0 {
			if m.coordination != nil {
				err := m.coordination.Client.Send(&proto.CoordinateRequest{
					RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
				})
				if err != nil {
					m.logger.Debug(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
					m.coordination.SendErr(xerrors.Errorf("unsubscribe expired agent: %w", err))
					// close the client because we do not want to do a graceful disconnect by
					// closing the coordination.
					_ = m.coordination.Client.Close()
					m.coordination = nil
					// Here we continue deleting any inactive agents: there is no point in
					// re-establishing tunnels to expired agents when we eventually reconnect.
				}
			}
			deletedCount++
			delete(m.connectionTimes, agentID)
		}
	}
	m.logger.Debug(ctx, "pruned inactive agents",
		slog.F("deleted", deletedCount),
		slog.F("took", time.Since(start)),
	)
}

func (m *MultiAgentController) Close() {
	m.cancel()
	<-m.expireOldAgentsDone
}

func NewMultiAgentController(ctx context.Context, logger slog.Logger, tracer trace.Tracer, coordinatee tailnet.Coordinatee) *MultiAgentController {
	m := &MultiAgentController{
		BasicCoordinationController: &tailnet.BasicCoordinationController{
			Logger:      logger,
			Coordinatee: coordinatee,
			SendAcks:    false, // we are a client, connecting to multiple agents
		},
		logger:              logger,
		tracer:              tracer,
		connectionTimes:     make(map[uuid.UUID]time.Time),
		tickets:             make(map[uuid.UUID]map[uuid.UUID]struct{}),
		expireOldAgentsDone: make(chan struct{}),
	}
	ctx, m.cancel = context.WithCancel(ctx)
	go m.expireOldAgents(ctx)
	return m
}

type Pinger interface {
	Ping(context.Context) (time.Duration, error)
}

// InmemTailnetDialer is a tailnet.ControlProtocolDialer that connects to a Coordinator and DERPMap
// service running in the same memory space.
type InmemTailnetDialer struct {
	CoordPtr *atomic.Pointer[tailnet.Coordinator]
	DERPFn   func() *tailcfg.DERPMap
	Logger   slog.Logger
	ClientID uuid.UUID
	// DatabaseHealthCheck is used to validate that the store is reachable.
	DatabaseHealthCheck Pinger
}

func (a *InmemTailnetDialer) Dial(ctx context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
	if a.DatabaseHealthCheck != nil {
		if _, err := a.DatabaseHealthCheck.Ping(ctx); err != nil {
			return tailnet.ControlProtocolClients{}, xerrors.Errorf("%w: %v", codersdk.ErrDatabaseNotReachable, err)
		}
	}

	coord := a.CoordPtr.Load()
	if coord == nil {
		return tailnet.ControlProtocolClients{}, xerrors.Errorf("tailnet coordinator not initialized")
	}
	coordClient := tailnet.NewInMemoryCoordinatorClient(
		a.Logger,
		a.ClientID,
		tailnet.SingleTailnetCoordinateeAuth{},
		*coord)
	derpClient := newPollingDERPClient(a.DERPFn, a.Logger)
	return tailnet.ControlProtocolClients{
		Closer:      closeAll{coord: coordClient, derp: derpClient},
		Coordinator: coordClient,
		DERP:        derpClient,
	}, nil
}

func newPollingDERPClient(derpFn func() *tailcfg.DERPMap, logger slog.Logger) tailnet.DERPClient {
	ctx, cancel := context.WithCancel(context.Background())
	a := &pollingDERPClient{
		fn:       derpFn,
		ctx:      ctx,
		cancel:   cancel,
		logger:   logger,
		ch:       make(chan *tailcfg.DERPMap),
		loopDone: make(chan struct{}),
	}
	go a.pollDERP()
	return a
}

// pollingDERPClient is a DERP client that just calls a function on a polling
// interval
type pollingDERPClient struct {
	fn          func() *tailcfg.DERPMap
	logger      slog.Logger
	ctx         context.Context
	cancel      context.CancelFunc
	loopDone    chan struct{}
	lastDERPMap *tailcfg.DERPMap
	ch          chan *tailcfg.DERPMap
}

// Close the DERP client
func (a *pollingDERPClient) Close() error {
	a.cancel()
	<-a.loopDone
	return nil
}

func (a *pollingDERPClient) Recv() (*tailcfg.DERPMap, error) {
	select {
	case <-a.ctx.Done():
		return nil, a.ctx.Err()
	case dm := <-a.ch:
		return dm, nil
	}
}

func (a *pollingDERPClient) pollDERP() {
	defer close(a.loopDone)
	defer a.logger.Debug(a.ctx, "polling DERPMap exited")

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-a.ctx.Done():
			return
		case <-ticker.C:
		}

		newDerpMap := a.fn()
		if !tailnet.CompareDERPMaps(a.lastDERPMap, newDerpMap) {
			select {
			case <-a.ctx.Done():
				return
			case a.ch <- newDerpMap:
				// Remember the last map we delivered so unchanged maps are not
				// re-sent on every tick.
				a.lastDERPMap = newDerpMap
			}
		}
	}
}

type closeAll struct {
	coord tailnet.CoordinatorClient
	derp  tailnet.DERPClient
}

func (c closeAll) Close() error {
	cErr := c.coord.Close()
	dErr := c.derp.Close()
	if cErr != nil {
		return cErr
	}
	return dErr
}
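
// newServerTailnetExample is an illustrative sketch, not part of the original
// file: it wires an InmemTailnetDialer into NewServerTailnet the way an
// in-process caller might. The coordinator pointer, DERP map function,
// embedded DERP server, and tracer provider are assumed to be supplied by the
// caller; a fresh client ID is generated purely for the example.
func newServerTailnetExample(
	ctx context.Context,
	logger slog.Logger,
	derpServer *derp.Server,
	coordPtr *atomic.Pointer[tailnet.Coordinator],
	derpMapFn func() *tailcfg.DERPMap,
	traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
	dialer := &InmemTailnetDialer{
		CoordPtr: coordPtr,
		DERPFn:   derpMapFn,
		Logger:   logger,
		ClientID: uuid.New(),
	}
	return NewServerTailnet(ctx, logger, derpServer, dialer,
		false, // derpForceWebSockets
		false, // blockEndpoints
		traceProvider,
	)
}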