mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat(coderd): expire agents from server tailnet (#9092)
This commit is contained in:
@ -407,6 +407,7 @@ func New(options *Options) *API {
|
||||
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
|
||||
},
|
||||
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
|
||||
api.TracerProvider,
|
||||
)
|
||||
if err != nil {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
|
@ -14,11 +14,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
@ -45,6 +47,7 @@ func NewServerTailnet(
|
||||
derpMap *tailcfg.DERPMap,
|
||||
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
|
||||
cache *wsconncache.Cache,
|
||||
traceProvider trace.TracerProvider,
|
||||
) (*ServerTailnet, error) {
|
||||
logger = logger.Named("servertailnet")
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
@ -58,15 +61,16 @@ func NewServerTailnet(
|
||||
|
||||
serverCtx, cancel := context.WithCancel(ctx)
|
||||
tn := &ServerTailnet{
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
getMultiAgent: getMultiAgent,
|
||||
cache: cache,
|
||||
agentNodes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
tracer: traceProvider.Tracer(tracing.TracerName),
|
||||
conn: conn,
|
||||
getMultiAgent: getMultiAgent,
|
||||
cache: cache,
|
||||
agentConnectionTimes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
tn.transport.MaxIdleConnsPerHost = 10
|
||||
@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.nodesMu.Lock()
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentNodes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
_ = agentConn
|
||||
// err := agentConn.UnsubscribeAgent(agentID)
|
||||
// if err != nil {
|
||||
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
// }
|
||||
// delete(s.agentNodes, agentID)
|
||||
s.doExpireOldAgents(cutoff)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(coadler): actually remove from the netmap, then reenable
|
||||
// the above
|
||||
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
|
||||
// TODO: add some attrs to this.
|
||||
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
|
||||
defer span.End()
|
||||
|
||||
start := time.Now()
|
||||
deletedCount := 0
|
||||
|
||||
s.nodesMu.Lock()
|
||||
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentConnectionTimes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
|
||||
ID: tailnet.NodeID(agentID),
|
||||
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if !deleted {
|
||||
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
|
||||
}
|
||||
|
||||
deletedCount++
|
||||
delete(s.agentConnectionTimes, agentID)
|
||||
err = agentConn.UnsubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
|
||||
slog.F("deleted", deletedCount),
|
||||
slog.F("took", time.Since(start)),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
|
||||
s.agentConn.Store(&agentConn)
|
||||
|
||||
// Resubscribe to all of the agents we're tracking.
|
||||
for agentID := range s.agentNodes {
|
||||
for agentID := range s.agentConnectionTimes {
|
||||
err := agentConn.SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
@ -212,14 +241,16 @@ type ServerTailnet struct {
|
||||
cancel func()
|
||||
|
||||
logger slog.Logger
|
||||
tracer trace.Tracer
|
||||
conn *tailnet.Conn
|
||||
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
|
||||
agentConn atomic.Pointer[tailnet.MultiAgentConn]
|
||||
cache *wsconncache.Cache
|
||||
nodesMu sync.Mutex
|
||||
// agentNodes is a map of agent tailnetNodes the server wants to keep a
|
||||
// connection to. It contains the last time the agent was connected to.
|
||||
agentNodes map[uuid.UUID]time.Time
|
||||
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
|
||||
// keep a connection to. It contains the last time the agent was connected
|
||||
// to.
|
||||
agentConnectionTimes map[uuid.UUID]time.Time
|
||||
// agentTickets holds a map of all open connections to an agent.
|
||||
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
|
||||
@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.nodesMu.Lock()
|
||||
defer s.nodesMu.Unlock()
|
||||
|
||||
_, ok := s.agentNodes[agentID]
|
||||
_, ok := s.agentConnectionTimes[agentID]
|
||||
// If we don't have the node, subscribe.
|
||||
if !ok {
|
||||
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
|
||||
@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
|
||||
s.agentNodes[agentID] = time.Now()
|
||||
s.agentConnectionTimes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
|
||||
id := uuid.New()
|
||||
s.nodesMu.Lock()
|
||||
s.agentTickets[agentID][id] = struct{}{}
|
||||
s.nodesMu.Unlock()
|
||||
|
||||
return func() {
|
||||
s.nodesMu.Lock()
|
||||
delete(s.agentTickets[agentID], id)
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
var (
|
||||
conn *codersdk.WorkspaceAgentConn
|
||||
ret = func() {}
|
||||
ret func()
|
||||
)
|
||||
|
||||
if s.getAgentConn().AgentIsLegacy(agentID) {
|
||||
@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
|
||||
conn = cconn.WorkspaceAgentConn
|
||||
ret = release
|
||||
} else {
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
err := s.ensureAgent(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
||||
}
|
||||
ret = s.acquireTicket(agentID)
|
||||
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
|
||||
reachable := conn.AwaitReachable(ctx)
|
||||
if !reachable {
|
||||
ret()
|
||||
conn.Close()
|
||||
return nil, nil, xerrors.New("agent is unreachable")
|
||||
}
|
||||
|
||||
@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
|
||||
nc, err := conn.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
release()
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return &netConnCloser{Conn: nc, close: func() {
|
||||
release()
|
||||
conn.Close()
|
||||
}}, err
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
@ -232,6 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
|
||||
manifest.DERPMap,
|
||||
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
|
||||
cache,
|
||||
trace.NewNoopTracerProvider(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
Reference in New Issue
Block a user