feat(coderd): expire agents from server tailnet (#9092)

This commit is contained in:
Colin Adler
2023-08-14 20:38:37 -05:00
committed by GitHub
parent a08f7b8fb9
commit 344d32b2f1
6 changed files with 161 additions and 41 deletions

View File

@ -407,6 +407,7 @@ func New(options *Options) *API {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
api.TracerProvider,
)
if err != nil {
panic("failed to setup server tailnet: " + err.Error())

View File

@ -14,11 +14,13 @@ import (
"time"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
@ -45,6 +47,7 @@ func NewServerTailnet(
derpMap *tailcfg.DERPMap,
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
cache *wsconncache.Cache,
traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
conn, err := tailnet.NewConn(&tailnet.Options{
@ -58,15 +61,16 @@ func NewServerTailnet(
serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
ctx: serverCtx,
cancel: cancel,
logger: logger,
tracer: traceProvider.Tracer(tracing.TracerName),
conn: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentConnectionTimes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
@ -139,25 +143,50 @@ func (s *ServerTailnet) expireOldAgents() {
case <-ticker.C:
}
s.nodesMu.Lock()
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentNodes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
_ = agentConn
// err := agentConn.UnsubscribeAgent(agentID)
// if err != nil {
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
// }
// delete(s.agentNodes, agentID)
s.doExpireOldAgents(cutoff)
}
}
// TODO(coadler): actually remove from the netmap, then reenable
// the above
func (s *ServerTailnet) doExpireOldAgents(cutoff time.Duration) {
// TODO: add some attrs to this.
ctx, span := s.tracer.Start(s.ctx, tracing.FuncName())
defer span.End()
start := time.Now()
deletedCount := 0
s.nodesMu.Lock()
s.logger.Debug(ctx, "pruning inactive agents", slog.F("agent_count", len(s.agentConnectionTimes)))
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentConnectionTimes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
deleted, err := s.conn.RemovePeer(tailnet.PeerSelector{
ID: tailnet.NodeID(agentID),
IP: netip.PrefixFrom(tailnet.IPFromUUID(agentID), 128),
})
if err != nil {
s.logger.Warn(ctx, "failed to remove peer from server tailnet", slog.Error(err))
continue
}
if !deleted {
s.logger.Warn(ctx, "peer didn't exist in tailnet", slog.Error(err))
}
deletedCount++
delete(s.agentConnectionTimes, agentID)
err = agentConn.UnsubscribeAgent(agentID)
if err != nil {
s.logger.Error(ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.nodesMu.Unlock()
}
s.nodesMu.Unlock()
s.logger.Debug(s.ctx, "successfully pruned inactive agents",
slog.F("deleted", deletedCount),
slog.F("took", time.Since(start)),
)
}
func (s *ServerTailnet) watchAgentUpdates() {
@ -196,7 +225,7 @@ func (s *ServerTailnet) reinitCoordinator() {
s.agentConn.Store(&agentConn)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
for agentID := range s.agentConnectionTimes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
@ -212,14 +241,16 @@ type ServerTailnet struct {
cancel func()
logger slog.Logger
tracer trace.Tracer
conn *tailnet.Conn
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agent tailnetNodes the server wants to keep a
// connection to. It contains the last time the agent was connected to.
agentNodes map[uuid.UUID]time.Time
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
// to.
agentConnectionTimes map[uuid.UUID]time.Time
// agentTockets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
@ -268,7 +299,7 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()
_, ok := s.agentNodes[agentID]
_, ok := s.agentConnectionTimes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
@ -279,14 +310,27 @@ func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}
s.agentNodes[agentID] = time.Now()
s.agentConnectionTimes[agentID] = time.Now()
return nil
}
func (s *ServerTailnet) acquireTicket(agentID uuid.UUID) (release func()) {
id := uuid.New()
s.nodesMu.Lock()
s.agentTickets[agentID][id] = struct{}{}
s.nodesMu.Unlock()
return func() {
s.nodesMu.Lock()
delete(s.agentTickets[agentID], id)
s.nodesMu.Unlock()
}
}
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
var (
conn *codersdk.WorkspaceAgentConn
ret = func() {}
ret func()
)
if s.getAgentConn().AgentIsLegacy(agentID) {
@ -299,12 +343,13 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
conn = cconn.WorkspaceAgentConn
ret = release
} else {
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
@ -317,7 +362,6 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
reachable := conn.AwaitReachable(ctx)
if !reachable {
ret()
conn.Close()
return nil, nil, xerrors.New("agent is unreachable")
}
@ -336,13 +380,11 @@ func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID,
nc, err := conn.DialContext(ctx, network, addr)
if err != nil {
release()
conn.Close()
return nil, xerrors.Errorf("dial context: %w", err)
}
return &netConnCloser{Conn: nc, close: func() {
release()
conn.Close()
}}, err
}

View File

@ -14,6 +14,7 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
@ -232,6 +233,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
manifest.DERPMap,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
cache,
trace.NewNoopTracerProvider(),
)
require.NoError(t, err)