mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore: replace wsconncache with a single tailnet (#8176)
This commit is contained in:
339
coderd/tailnet.go
Normal file
339
coderd/tailnet.go
Normal file
@ -0,0 +1,339 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
var tailnetTransport *http.Transport
|
||||
|
||||
func init() {
|
||||
var valid bool
|
||||
tailnetTransport, valid = http.DefaultTransport.(*http.Transport)
|
||||
if !valid {
|
||||
panic("dev error: default transport is the wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
// NewServerTailnet creates a new tailnet intended for use by coderd. It
|
||||
// automatically falls back to wsconncache if a legacy agent is encountered.
|
||||
func NewServerTailnet(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
derpServer *derp.Server,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
coord *atomic.Pointer[tailnet.Coordinator],
|
||||
cache *wsconncache.Cache,
|
||||
) (*ServerTailnet, error) {
|
||||
logger = logger.Named("servertailnet")
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: derpMap,
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
||||
}
|
||||
|
||||
serverCtx, cancel := context.WithCancel(ctx)
|
||||
tn := &ServerTailnet{
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
coord: coord,
|
||||
cache: cache,
|
||||
agentNodes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
tn.transport.MaxIdleConnsPerHost = 10
|
||||
tn.transport.MaxIdleConns = 0
|
||||
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
|
||||
tn.agentConn.Store(&agentConn)
|
||||
|
||||
err = tn.getAgentConn().UpdateSelf(conn.Node())
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
|
||||
}
|
||||
conn.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := tn.getAgentConn().UpdateSelf(node)
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
|
||||
}
|
||||
})
|
||||
|
||||
// This is set to allow local DERP traffic to be proxied through memory
|
||||
// instead of needing to hit the external access URL. Don't use the ctx
|
||||
// given in this callback, it's only valid while connecting.
|
||||
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
||||
if !region.EmbeddedRelay {
|
||||
return nil
|
||||
}
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
||||
derpServer.Accept(ctx, right, brw, "internal")
|
||||
}()
|
||||
return left
|
||||
})
|
||||
|
||||
go tn.watchAgentUpdates()
|
||||
go tn.expireOldAgents()
|
||||
return tn, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) expireOldAgents() {
|
||||
const (
|
||||
tick = 5 * time.Minute
|
||||
cutoff = 30 * time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.nodesMu.Lock()
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentNodes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
_ = agentConn
|
||||
// err := agentConn.UnsubscribeAgent(agentID)
|
||||
// if err != nil {
|
||||
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
// }
|
||||
// delete(s.agentNodes, agentID)
|
||||
|
||||
// TODO(coadler): actually remove from the netmap, then reenable
|
||||
// the above
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
for {
|
||||
conn := s.getAgentConn()
|
||||
nodes, ok := conn.NextUpdate(s.ctx)
|
||||
if !ok {
|
||||
if conn.IsClosed() && s.ctx.Err() == nil {
|
||||
s.reinitCoordinator()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.conn.UpdateNodes(nodes, false)
|
||||
if err != nil {
|
||||
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
|
||||
return *s.agentConn.Load()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) reinitCoordinator() {
|
||||
s.nodesMu.Lock()
|
||||
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
|
||||
s.agentConn.Store(&agentConn)
|
||||
|
||||
// Resubscribe to all of the agents we're tracking.
|
||||
for agentID := range s.agentNodes {
|
||||
err := agentConn.SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
|
||||
type ServerTailnet struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
logger slog.Logger
|
||||
conn *tailnet.Conn
|
||||
coord *atomic.Pointer[tailnet.Coordinator]
|
||||
agentConn atomic.Pointer[tailnet.MultiAgentConn]
|
||||
cache *wsconncache.Cache
|
||||
nodesMu sync.Mutex
|
||||
// agentNodes is a map of agent tailnetNodes the server wants to keep a
|
||||
// connection to. It contains the last time the agent was connected to.
|
||||
agentNodes map[uuid.UUID]time.Time
|
||||
// agentTockets holds a map of all open connections to an agent.
|
||||
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: dashboardURL.String(),
|
||||
})
|
||||
}
|
||||
proxy.Director = s.director(agentID, proxy.Director)
|
||||
proxy.Transport = s.transport
|
||||
|
||||
return proxy, func() {}, nil
|
||||
}
|
||||
|
||||
type agentIDKey struct{}
|
||||
|
||||
// director makes sure agentIDKey is set on the context in the reverse proxy.
|
||||
// This allows the transport to correctly identify which agent to dial to.
|
||||
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
|
||||
return func(req *http.Request) {
|
||||
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
|
||||
*req = *req.WithContext(ctx)
|
||||
prev(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("no agent id attached")
|
||||
}
|
||||
|
||||
return s.DialAgentNetConn(ctx, agentID, network, addr)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.nodesMu.Lock()
|
||||
defer s.nodesMu.Unlock()
|
||||
|
||||
_, ok := s.agentNodes[agentID]
|
||||
// If we don't have the node, subscribe.
|
||||
if !ok {
|
||||
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
|
||||
err := s.getAgentConn().SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
|
||||
s.agentNodes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
var (
|
||||
conn *codersdk.WorkspaceAgentConn
|
||||
ret = func() {}
|
||||
)
|
||||
|
||||
if s.getAgentConn().AgentIsLegacy(agentID) {
|
||||
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
|
||||
cconn, release, err := s.cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
|
||||
}
|
||||
|
||||
conn = cconn.WorkspaceAgentConn
|
||||
ret = release
|
||||
} else {
|
||||
err := s.ensureAgent(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
})
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
reachable := conn.AwaitReachable(ctx)
|
||||
if !reachable {
|
||||
ret()
|
||||
conn.Close()
|
||||
return nil, nil, xerrors.New("agent is unreachable")
|
||||
}
|
||||
|
||||
return conn, ret, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
|
||||
conn, release, err := s.AgentConn(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("acquire agent conn: %w", err)
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
nc, err := conn.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
release()
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return &netConnCloser{Conn: nc, close: func() {
|
||||
release()
|
||||
conn.Close()
|
||||
}}, err
|
||||
}
|
||||
|
||||
type netConnCloser struct {
|
||||
net.Conn
|
||||
close func()
|
||||
}
|
||||
|
||||
func (c *netConnCloser) Close() error {
|
||||
c.close()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) Close() error {
|
||||
s.cancel()
|
||||
_ = s.cache.Close()
|
||||
_ = s.conn.Close()
|
||||
s.transport.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user