diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a548ac0da2..9bc3a187f6 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -784,6 +784,20 @@ func (q *querier) GetActiveUserCount(ctx context.Context) (int64, error) { return q.db.GetActiveUserCount(ctx) } +func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return []database.TailnetAgent{}, err + } + return q.db.GetAllTailnetAgents(ctx) +} + +func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return []database.TailnetClient{}, err + } + return q.db.GetAllTailnetClients(ctx) +} + func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) { // No authz checks return q.db.GetAppSecurityKey(ctx) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 6d8eb2f4e9..8e29bae434 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -903,6 +903,14 @@ func (q *FakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) { return active, nil } +func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAgent, error) { + return nil, ErrUnimplemented +} + +func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) { + return nil, ErrUnimplemented +} + func (q *FakeQuerier) GetAppSecurityKey(_ context.Context) (string, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 9d389ef837..95dde653ca 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -237,6 +237,20 @@ func (m metricsStore) GetActiveUserCount(ctx context.Context) (int64, error) { return count, err } +func (m metricsStore) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAgent, error) { + start := time.Now() + r0, r1 := m.s.GetAllTailnetAgents(ctx) + m.queryLatencies.WithLabelValues("GetAllTailnetAgents").Observe(time.Since(start).Seconds()) + return r0, r1 +} + +func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) { + start := time.Now() + r0, r1 := m.s.GetAllTailnetClients(ctx) + m.queryLatencies.WithLabelValues("GetAllTailnetClients").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetAppSecurityKey(ctx context.Context) (string, error) { start := time.Now() key, err := m.s.GetAppSecurityKey(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 5bf58ec5a1..f6ee26a15f 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -371,6 +371,36 @@ func (mr *MockStoreMockRecorder) GetActiveUserCount(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveUserCount", reflect.TypeOf((*MockStore)(nil).GetActiveUserCount), arg0) } +// GetAllTailnetAgents mocks base method. +func (m *MockStore) GetAllTailnetAgents(arg0 context.Context) ([]database.TailnetAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllTailnetAgents", arg0) + ret0, _ := ret[0].([]database.TailnetAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllTailnetAgents indicates an expected call of GetAllTailnetAgents. +func (mr *MockStoreMockRecorder) GetAllTailnetAgents(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTailnetAgents", reflect.TypeOf((*MockStore)(nil).GetAllTailnetAgents), arg0) +} + +// GetAllTailnetClients mocks base method. +func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.TailnetClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllTailnetClients", arg0) + ret0, _ := ret[0].([]database.TailnetClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllTailnetClients indicates an expected call of GetAllTailnetClients. +func (mr *MockStoreMockRecorder) GetAllTailnetClients(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllTailnetClients", reflect.TypeOf((*MockStore)(nil).GetAllTailnetClients), arg0) +} + // GetAppSecurityKey mocks base method. func (m *MockStore) GetAppSecurityKey(arg0 context.Context) (string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 0baaab0488..be33b00bfc 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -48,6 +48,8 @@ type sqlcQuerier interface { GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) GetActiveUserCount(ctx context.Context) (int64, error) + GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error) + GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) GetAppSecurityKey(ctx context.Context) (string, error) // GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided // ID. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index ca37955650..87d86e9621 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3694,6 +3694,74 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC return i, err } +const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many +SELECT id, coordinator_id, updated_at, node +FROM tailnet_agents +` + +func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error) { + rows, err := q.db.QueryContext(ctx, getAllTailnetAgents) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetAgent + for rows.Next() { + var i TailnetAgent + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getAllTailnetClients = `-- name: GetAllTailnetClients :many +SELECT id, coordinator_id, agent_id, updated_at, node +FROM tailnet_clients +ORDER BY agent_id +` + +func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) { + rows, err := q.db.QueryContext(ctx, getAllTailnetClients) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetClient + for rows.Next() { + var i TailnetClient + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.AgentID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getTailnetAgents = `-- name: GetTailnetAgents :many SELECT id, coordinator_id, updated_at, node FROM tailnet_agents diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql index d16e4a3b4b..fd2db296df 100644 --- a/coderd/database/queries/tailnet.sql +++ b/coderd/database/queries/tailnet.sql @@ -59,11 +59,20 @@ SELECT * FROM tailnet_agents WHERE id = $1; +-- name: GetAllTailnetAgents :many +SELECT * +FROM tailnet_agents; + -- name: GetTailnetClientsForAgent :many SELECT * FROM tailnet_clients WHERE agent_id = $1; +-- name: GetAllTailnetClients :many +SELECT * +FROM tailnet_clients +ORDER BY agent_id; + -- name: UpsertTailnetCoordinator :one INSERT INTO tailnet_coordinators ( diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 9974f803bd..fd0d05149c 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -704,5 +704,7 @@ func (c *haCoordinator) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { c.mutex.RLock() defer c.mutex.RUnlock() - agpl.CoordinatorHTTPDebug(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache)(w, r) + agpl.CoordinatorHTTPDebug( + agpl.HTTPDebugFromLocal(true, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache), + )(w, r) } diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go index 25396e9c84..8693d8e9a5 100644 --- a/enterprise/tailnet/pgcoord.go +++ b/enterprise/tailnet/pgcoord.go @@ -13,6 +13,7 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/google/uuid" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -307,6 +308,9 @@ type binding struct { node *agpl.Node } +func (b *binding) isAgent() bool { return b.client == uuid.Nil } +func (b *binding) isClient() bool { return b.client != uuid.Nil } + // binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. type binder struct { ctx context.Context @@ -386,19 +390,19 @@ func (b *binder) writeOne(bnd binding) error { if err != nil { // this is very bad news, but it should never happen because the node was Unmarshalled by this process // earlier. - b.logger.Error(b.ctx, "failed to marshall node", slog.Error(err)) + b.logger.Error(b.ctx, "failed to marshal node", slog.Error(err)) return err } } switch { - case bnd.client == uuid.Nil && len(nodeRaw) > 0: + case bnd.isAgent() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ ID: bnd.agent, CoordinatorID: b.coordinatorID, Node: nodeRaw, }) - case bnd.client == uuid.Nil && len(nodeRaw) == 0: + case bnd.isAgent() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ ID: bnd.agent, CoordinatorID: b.coordinatorID, @@ -407,14 +411,14 @@ func (b *binder) writeOne(bnd binding) error { // treat deletes as idempotent err = nil } - case bnd.client != uuid.Nil && len(nodeRaw) > 0: + case bnd.isClient() && len(nodeRaw) > 0: _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ ID: bnd.client, CoordinatorID: b.coordinatorID, AgentID: bnd.agent, Node: nodeRaw, }) - case bnd.client != uuid.Nil && len(nodeRaw) == 0: + case bnd.isClient() && len(nodeRaw) == 0: _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ ID: bnd.client, CoordinatorID: b.coordinatorID, @@ -927,6 +931,27 @@ func (q *querier) updateAll() { } } +func (q *querier) getAll(ctx context.Context) (map[uuid.UUID]database.TailnetAgent, map[uuid.UUID][]database.TailnetClient, error) { + agents, err := q.store.GetAllTailnetAgents(ctx) + if err != nil { + return nil, nil, xerrors.Errorf("get all tailnet agents: %w", err) + } + agentsMap := map[uuid.UUID]database.TailnetAgent{} + for _, agent := range agents { + agentsMap[agent.ID] = agent + } + clients, err := q.store.GetAllTailnetClients(ctx) + if err != nil { + return nil, nil, xerrors.Errorf("get all tailnet clients: %w", err) + } + clientsMap := map[uuid.UUID][]database.TailnetClient{} + for _, client := range clients { + clientsMap[client.AgentID] = append(clientsMap[client.AgentID], client) + } + + return agentsMap, clientsMap, nil +} + func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) { parts := strings.Split(msg, ",") if len(parts) != 2 { @@ -1289,8 +1314,90 @@ func (h *heartbeats) cleanup() { h.logger.Debug(h.ctx, "cleaned up old coordinators") } -func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { - // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("Coder Enterprise PostgreSQL distributed tailnet coordinator")) +func (c *pgCoord) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + debug, err := c.htmlDebug(ctx) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + return + } + + agpl.CoordinatorHTTPDebug(debug)(w, r) +} + +func (c *pgCoord) htmlDebug(ctx context.Context) (agpl.HTMLDebug, error) { + now := time.Now() + data := agpl.HTMLDebug{} + agents, clients, err := c.querier.getAll(ctx) + if err != nil { + return data, xerrors.Errorf("get all agents and clients: %w", err) + } + + for _, agent := range agents { + htmlAgent := &agpl.HTMLAgent{ + ID: agent.ID, + // Name: ??, TODO: get agent names + LastWriteAge: now.Sub(agent.UpdatedAt).Round(time.Second), + } + for _, conn := range clients[agent.ID] { + htmlAgent.Connections = append(htmlAgent.Connections, &agpl.HTMLClient{ + ID: conn.ID, + Name: conn.ID.String(), + LastWriteAge: now.Sub(conn.UpdatedAt).Round(time.Second), + }) + data.Nodes = append(data.Nodes, &agpl.HTMLNode{ + ID: conn.ID, + Node: conn.Node, + }) + } + slices.SortFunc(htmlAgent.Connections, func(a, b *agpl.HTMLClient) bool { + return a.Name < b.Name + }) + + data.Agents = append(data.Agents, htmlAgent) + data.Nodes = append(data.Nodes, &agpl.HTMLNode{ + ID: agent.ID, + // Name: ??, TODO: get agent names + Node: agent.Node, + }) + } + slices.SortFunc(data.Agents, func(a, b *agpl.HTMLAgent) bool { + return a.Name < b.Name + }) + + for agentID, conns := range clients { + if len(conns) == 0 { + continue + } + + if _, ok := agents[agentID]; ok { + continue + } + agent := &agpl.HTMLAgent{ + Name: "unknown", + ID: agentID, + } + for _, conn := range conns { + agent.Connections = append(agent.Connections, &agpl.HTMLClient{ + Name: conn.ID.String(), + ID: conn.ID, + LastWriteAge: now.Sub(conn.UpdatedAt).Round(time.Second), + }) + data.Nodes = append(data.Nodes, &agpl.HTMLNode{ + ID: conn.ID, + Node: conn.Node, + }) + } + slices.SortFunc(agent.Connections, func(a, b *agpl.HTMLClient) bool { + return a.Name < b.Name + }) + + data.MissingAgents = append(data.MissingAgents, agent) + } + slices.SortFunc(data.MissingAgents, func(a, b *agpl.HTMLAgent) bool { + return a.Name < b.Name + }) + + return data, nil } diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index a5347f981a..d37ea5c290 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -650,21 +650,106 @@ func (c *core) serveHTTPDebug(w http.ResponseWriter, r *http.Request) { c.mutex.RLock() defer c.mutex.RUnlock() - CoordinatorHTTPDebug(false, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache)(w, r) + CoordinatorHTTPDebug( + HTTPDebugFromLocal(false, c.agentSockets, c.agentToConnectionSockets, c.nodes, c.agentNameCache), + )(w, r) } -func CoordinatorHTTPDebug( +func HTTPDebugFromLocal( ha bool, agentSocketsMap map[uuid.UUID]Queue, agentToConnectionSocketsMap map[uuid.UUID]map[uuid.UUID]Queue, nodesMap map[uuid.UUID]*Node, agentNameCache *lru.Cache[uuid.UUID, string], -) func(w http.ResponseWriter, _ *http.Request) { +) HTMLDebug { + now := time.Now() + data := HTMLDebug{HA: ha} + for id, conn := range agentSocketsMap { + start, lastWrite := conn.Stats() + agent := &HTMLAgent{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + Overwrites: int(conn.Overwrites()), + } + + for id, conn := range agentToConnectionSocketsMap[id] { + start, lastWrite := conn.Stats() + agent.Connections = append(agent.Connections, &HTMLClient{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + }) + } + slices.SortFunc(agent.Connections, func(a, b *HTMLClient) bool { + return a.Name < b.Name + }) + + data.Agents = append(data.Agents, agent) + } + slices.SortFunc(data.Agents, func(a, b *HTMLAgent) bool { + return a.Name < b.Name + }) + + for agentID, conns := range agentToConnectionSocketsMap { + if len(conns) == 0 { + continue + } + + if _, ok := agentSocketsMap[agentID]; ok { + continue + } + + agentName, ok := agentNameCache.Get(agentID) + if !ok { + agentName = "unknown" + } + agent := &HTMLAgent{ + Name: agentName, + ID: agentID, + } + for id, conn := range conns { + start, lastWrite := conn.Stats() + agent.Connections = append(agent.Connections, &HTMLClient{ + Name: conn.Name(), + ID: id, + CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), + LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), + }) + } + slices.SortFunc(agent.Connections, func(a, b *HTMLClient) bool { + return a.Name < b.Name + }) + + data.MissingAgents = append(data.MissingAgents, agent) + } + slices.SortFunc(data.MissingAgents, func(a, b *HTMLAgent) bool { + return a.Name < b.Name + }) + + for id, node := range nodesMap { + name, _ := agentNameCache.Get(id) + data.Nodes = append(data.Nodes, &HTMLNode{ + ID: id, + Name: name, + Node: node, + }) + } + slices.SortFunc(data.Nodes, func(a, b *HTMLNode) bool { + return a.Name+a.ID.String() < b.Name+b.ID.String() + }) + + return data +} + +func CoordinatorHTTPDebug(data HTMLDebug) func(w http.ResponseWriter, _ *http.Request) { return func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl, err := template.New("coordinator_debug").Funcs(template.FuncMap{ - "marshal": func(v interface{}) template.JS { + "marshal": func(v any) template.JS { a, err := json.MarshalIndent(v, "", " ") if err != nil { //nolint:gosec @@ -680,83 +765,6 @@ func CoordinatorHTTPDebug( return } - now := time.Now() - data := htmlDebug{HA: ha} - for id, conn := range agentSocketsMap { - start, lastWrite := conn.Stats() - agent := &htmlAgent{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - Overwrites: int(conn.Overwrites()), - } - - for id, conn := range agentToConnectionSocketsMap[id] { - start, lastWrite := conn.Stats() - agent.Connections = append(agent.Connections, &htmlClient{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - }) - } - slices.SortFunc(agent.Connections, func(a, b *htmlClient) bool { - return a.Name < b.Name - }) - - data.Agents = append(data.Agents, agent) - } - slices.SortFunc(data.Agents, func(a, b *htmlAgent) bool { - return a.Name < b.Name - }) - - for agentID, conns := range agentToConnectionSocketsMap { - if len(conns) == 0 { - continue - } - - if _, ok := agentSocketsMap[agentID]; !ok { - agentName, ok := agentNameCache.Get(agentID) - if !ok { - agentName = "unknown" - } - agent := &htmlAgent{ - Name: agentName, - ID: agentID, - } - for id, conn := range conns { - start, lastWrite := conn.Stats() - agent.Connections = append(agent.Connections, &htmlClient{ - Name: conn.Name(), - ID: id, - CreatedAge: now.Sub(time.Unix(start, 0)).Round(time.Second), - LastWriteAge: now.Sub(time.Unix(lastWrite, 0)).Round(time.Second), - }) - } - slices.SortFunc(agent.Connections, func(a, b *htmlClient) bool { - return a.Name < b.Name - }) - - data.MissingAgents = append(data.MissingAgents, agent) - } - } - slices.SortFunc(data.MissingAgents, func(a, b *htmlAgent) bool { - return a.Name < b.Name - }) - - for id, node := range nodesMap { - name, _ := agentNameCache.Get(id) - data.Nodes = append(data.Nodes, &htmlNode{ - ID: id, - Name: name, - Node: node, - }) - } - slices.SortFunc(data.Nodes, func(a, b *htmlNode) bool { - return a.Name+a.ID.String() < b.Name+b.ID.String() - }) - err = tmpl.Execute(w, data) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -766,33 +774,33 @@ func CoordinatorHTTPDebug( } } -type htmlDebug struct { +type HTMLDebug struct { HA bool - Agents []*htmlAgent - MissingAgents []*htmlAgent - Nodes []*htmlNode + Agents []*HTMLAgent + MissingAgents []*HTMLAgent + Nodes []*HTMLNode } -type htmlAgent struct { +type HTMLAgent struct { Name string ID uuid.UUID CreatedAge time.Duration LastWriteAge time.Duration Overwrites int - Connections []*htmlClient + Connections []*HTMLClient } -type htmlClient struct { +type HTMLClient struct { Name string ID uuid.UUID CreatedAge time.Duration LastWriteAge time.Duration } -type htmlNode struct { +type HTMLNode struct { ID uuid.UUID Name string - Node *Node + Node any } var coordinatorDebugTmpl = `