Daily Active User Metrics (#3735)

* agent: add StatsReporter

* Stabilize protoc
This commit is contained in:
Ammar Bandukwala
2022-09-01 14:58:23 -05:00
committed by GitHub
parent e0cb52ceea
commit 30f8fd9b95
47 changed files with 2006 additions and 279 deletions

View File

@ -30,6 +30,7 @@ import (
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/metricscache"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
@ -76,6 +77,9 @@ type Options struct {
TailscaleEnable bool
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
}
// New constructs a Coder API handler.
@ -121,6 +125,12 @@ func New(options *Options) *API {
panic(xerrors.Errorf("read site bin failed: %w", err))
}
metricsCache := metricscache.New(
options.Database,
options.Logger.Named("metrics_cache"),
options.MetricsCacheRefreshInterval,
)
r := chi.NewRouter()
api := &API{
Options: options,
@ -130,6 +140,7 @@ func New(options *Options) *API {
Authorizer: options.Authorizer,
Logger: options.Logger,
},
metricsCache: metricsCache,
}
if options.TailscaleEnable {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
@ -147,6 +158,13 @@ func New(options *Options) *API {
httpmw.Recover(api.Logger),
httpmw.Logger(api.Logger),
httpmw.Prometheus(options.PrometheusRegistry),
// Build-Version is helpful for debugging.
func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Build-Version", buildinfo.Version())
next.ServeHTTP(w, r)
})
},
)
apps := func(r chi.Router) {
@ -259,7 +277,7 @@ func New(options *Options) *API {
apiKeyMiddleware,
httpmw.ExtractTemplateParam(options.Database),
)
r.Get("/daus", api.templateDAUs)
r.Get("/", api.template)
r.Delete("/", api.deleteTemplate)
r.Patch("/", api.patchTemplateMeta)
@ -359,11 +377,14 @@ func New(options *Options) *API {
r.Get("/metadata", api.workspaceAgentMetadata)
r.Post("/version", api.postWorkspaceAgentVersion)
r.Get("/listen", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/coordinate", api.workspaceAgentCoordinate)
r.Get("/report-stats", api.workspaceAgentReportStats)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
@ -452,6 +473,8 @@ type API struct {
websocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
httpAuth *HTTPAuthorizer
metricsCache *metricscache.Cache
}
// Close waits for all WebSocket connections to drain before returning.
@ -460,6 +483,8 @@ func (api *API) Close() error {
api.websocketWaitGroup.Wait()
api.websocketWaitMutex.Unlock()
api.metricsCache.Close()
return api.workspaceAgentCache.Close()
}

View File

@ -197,6 +197,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!

View File

@ -234,7 +234,9 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
},
},
},
AutoImportTemplates: options.AutoImportTemplates,
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: time.Millisecond * 100,
AgentStatsRefreshInterval: time.Millisecond * 100,
})
t.Cleanup(func() {
_ = coderAPI.Close()

View File

@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"github.com/coder/coder/coderd/database"
@ -23,6 +24,7 @@ func New() database.Store {
mutex: &sync.RWMutex{},
data: &data{
apiKeys: make([]database.APIKey, 0),
agentStats: make([]database.AgentStat, 0),
organizationMembers: make([]database.OrganizationMember, 0),
organizations: make([]database.Organization, 0),
users: make([]database.User, 0),
@ -78,6 +80,7 @@ type data struct {
userLinks []database.UserLink
// New tables
agentStats []database.AgentStat
auditLogs []database.AuditLog
files []database.File
gitSSHKey []database.GitSSHKey
@ -134,6 +137,64 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
}
return database.ProvisionerJob{}, sql.ErrNoRows
}
// DeleteOldAgentStats is a no-op in the in-memory fake: test data is
// short-lived, so there is nothing to prune.
func (*fakeQuerier) DeleteOldAgentStats(_ context.Context) error {
	// no-op
	return nil
}
// InsertAgentStat stores one agent stat report in the in-memory store and
// returns the row as stored.
func (q *fakeQuerier) InsertAgentStat(_ context.Context, p database.InsertAgentStatParams) (database.AgentStat, error) {
	q.mutex.Lock()
	defer q.mutex.Unlock()

	inserted := database.AgentStat{
		ID:          p.ID,
		CreatedAt:   p.CreatedAt,
		UserID:      p.UserID,
		AgentID:     p.AgentID,
		WorkspaceID: p.WorkspaceID,
		TemplateID:  p.TemplateID,
		Payload:     p.Payload,
	}
	q.agentStats = append(q.agentStats, inserted)
	return inserted, nil
}
// GetTemplateDAUs computes daily-active-user rows for a template from the
// in-memory agent stats, mirroring the SQL aggregation: one row per UTC day
// containing the count of distinct user IDs seen that day, sorted ascending
// by date.
func (q *fakeQuerier) GetTemplateDAUs(_ context.Context, templateID uuid.UUID) ([]database.GetTemplateDAUsRow, error) {
	// This method only reads, so a read lock suffices and lets concurrent
	// readers proceed (the mutex is an RWMutex).
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	// date -> set of user IDs active on that day.
	counts := make(map[time.Time]map[string]struct{})
	for _, as := range q.agentStats {
		if as.TemplateID != templateID {
			continue
		}
		date := as.CreatedAt.Truncate(time.Hour * 24)
		if counts[date] == nil {
			counts[date] = make(map[string]struct{})
		}
		counts[date][as.UserID.String()] = struct{}{}
	}

	countKeys := maps.Keys(counts)
	sort.Slice(countKeys, func(i, j int) bool {
		return countKeys[i].Before(countKeys[j])
	})

	var rs []database.GetTemplateDAUsRow
	for _, key := range countKeys {
		rs = append(rs, database.GetTemplateDAUsRow{
			Date:   key,
			Amount: int64(len(counts[key])),
		})
	}
	return rs, nil
}
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
q.mutex.Lock()

View File

@ -87,6 +87,16 @@ CREATE TYPE workspace_transition AS ENUM (
'delete'
);
CREATE TABLE agent_stats (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
user_id uuid NOT NULL,
agent_id uuid NOT NULL,
workspace_id uuid NOT NULL,
template_id uuid NOT NULL,
payload jsonb NOT NULL
);
CREATE TABLE api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
@ -372,6 +382,9 @@ CREATE TABLE workspaces (
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('public.licenses_id_seq'::regclass);
ALTER TABLE ONLY agent_stats
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
@ -468,6 +481,10 @@ ALTER TABLE ONLY workspace_resources
ALTER TABLE ONLY workspaces
ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
CREATE INDEX idx_agent_stats_created_at ON agent_stats USING btree (created_at);
CREATE INDEX idx_agent_stats_user_id ON agent_stats USING btree (user_id);
CREATE INDEX idx_api_keys_user ON api_keys USING btree (user_id);
CREATE INDEX idx_audit_log_organization_id ON audit_logs USING btree (organization_id);

View File

@ -0,0 +1 @@
-- Revert the agent_stats table created by the companion .up.sql migration.
DROP TABLE agent_stats;

View File

@ -0,0 +1,16 @@
-- agent_stats stores periodic stat reports from workspace agents. Rows feed
-- the per-template DAU aggregation and are pruned by DeleteOldAgentStats.
CREATE TABLE agent_stats (
	id uuid NOT NULL,
	PRIMARY KEY (id),
	created_at timestamptz NOT NULL,
	user_id uuid NOT NULL,
	agent_id uuid NOT NULL,
	workspace_id uuid NOT NULL,
	template_id uuid NOT NULL,
	payload jsonb NOT NULL
);

-- We use created_at for DAU analysis and pruning.
CREATE INDEX idx_agent_stats_created_at ON agent_stats USING btree (created_at);

-- We perform user grouping to analyze DAUs.
CREATE INDEX idx_agent_stats_user_id ON agent_stats USING btree (user_id);

View File

@ -324,6 +324,16 @@ type APIKey struct {
IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"`
}
// AgentStat is a single stats report from a workspace agent, mapped to the
// agent_stats table. The report body is kept as raw JSON in Payload.
// NOTE(review): this looks like generated model code — prefer regenerating
// from the schema over hand edits.
type AgentStat struct {
	ID          uuid.UUID       `db:"id" json:"id"`
	CreatedAt   time.Time       `db:"created_at" json:"created_at"`
	UserID      uuid.UUID       `db:"user_id" json:"user_id"`
	AgentID     uuid.UUID       `db:"agent_id" json:"agent_id"`
	WorkspaceID uuid.UUID       `db:"workspace_id" json:"workspace_id"`
	TemplateID  uuid.UUID       `db:"template_id" json:"template_id"`
	Payload     json.RawMessage `db:"payload" json:"payload"`
}
type AuditLog struct {
ID uuid.UUID `db:"id" json:"id"`
Time time.Time `db:"time" json:"time"`

View File

@ -22,6 +22,7 @@ type querier interface {
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOldAgentStats(ctx context.Context) error
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
@ -57,6 +58,7 @@ type querier interface {
GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error)
GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error)
GetTemplateDAUs(ctx context.Context, templateID uuid.UUID) ([]GetTemplateDAUsRow, error)
GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error)
GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error)
GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error)
@ -99,6 +101,7 @@ type querier interface {
GetWorkspaces(ctx context.Context, arg GetWorkspacesParams) ([]Workspace, error)
GetWorkspacesAutostart(ctx context.Context) ([]Workspace, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertAgentStat(ctx context.Context, arg InsertAgentStatParams) (AgentStat, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertDeploymentID(ctx context.Context, value string) error
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)

View File

@ -15,6 +15,104 @@ import (
"github.com/tabbed/pqtype"
)
// deleteOldAgentStats prunes agent_stats rows older than 30 days; stats are
// only needed for recent DAU aggregation.
// NOTE(review): sqlc-generated ("-- name:" marker) — edit the source .sql
// and regenerate rather than hand-editing here.
const deleteOldAgentStats = `-- name: DeleteOldAgentStats :exec
DELETE FROM AGENT_STATS WHERE created_at < now() - interval '30 days'
`

// DeleteOldAgentStats executes the pruning query above.
func (q *sqlQuerier) DeleteOldAgentStats(ctx context.Context) error {
	_, err := q.db.ExecContext(ctx, deleteOldAgentStats)
	return err
}
// getTemplateDAUs aggregates agent_stats into one row per UTC calendar day
// with the count of distinct users active that day, ascending by date.
// NOTE(review): sqlc-generated — edit the source .sql and regenerate rather
// than hand-editing here.
const getTemplateDAUs = `-- name: GetTemplateDAUs :many
select
(created_at at TIME ZONE 'UTC')::date as date,
count(distinct(user_id)) as amount
from
agent_stats
where template_id = $1
group by
date
order by
date asc
`

// GetTemplateDAUsRow is one day of daily-active-user data for a template.
type GetTemplateDAUsRow struct {
	Date   time.Time `db:"date" json:"date"`
	Amount int64     `db:"amount" json:"amount"`
}

// GetTemplateDAUs runs the DAU aggregation for a single template.
func (q *sqlQuerier) GetTemplateDAUs(ctx context.Context, templateID uuid.UUID) ([]GetTemplateDAUsRow, error) {
	rows, err := q.db.QueryContext(ctx, getTemplateDAUs, templateID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []GetTemplateDAUsRow
	for rows.Next() {
		var i GetTemplateDAUsRow
		if err := rows.Scan(&i.Date, &i.Amount); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// insertAgentStat persists a single agent stat report. Note the placeholder
// order is (id, created_at, user_id, workspace_id, template_id, agent_id,
// payload) — the bind order in InsertAgentStat must match it.
// NOTE(review): sqlc-generated — edit the source .sql and regenerate rather
// than hand-editing here.
const insertAgentStat = `-- name: InsertAgentStat :one
INSERT INTO
agent_stats (
id,
created_at,
user_id,
workspace_id,
template_id,
agent_id,
payload
)
VALUES
($1, $2, $3, $4, $5, $6, $7) RETURNING id, created_at, user_id, agent_id, workspace_id, template_id, payload
`

// InsertAgentStatParams mirrors the query's placeholder order (WorkspaceID
// and TemplateID precede AgentID), not the table column order.
type InsertAgentStatParams struct {
	ID          uuid.UUID       `db:"id" json:"id"`
	CreatedAt   time.Time       `db:"created_at" json:"created_at"`
	UserID      uuid.UUID       `db:"user_id" json:"user_id"`
	WorkspaceID uuid.UUID       `db:"workspace_id" json:"workspace_id"`
	TemplateID  uuid.UUID       `db:"template_id" json:"template_id"`
	AgentID     uuid.UUID       `db:"agent_id" json:"agent_id"`
	Payload     json.RawMessage `db:"payload" json:"payload"`
}

// InsertAgentStat inserts one stats row and scans back the stored record.
func (q *sqlQuerier) InsertAgentStat(ctx context.Context, arg InsertAgentStatParams) (AgentStat, error) {
	row := q.db.QueryRowContext(ctx, insertAgentStat,
		arg.ID,
		arg.CreatedAt,
		arg.UserID,
		arg.WorkspaceID,
		arg.TemplateID,
		arg.AgentID,
		arg.Payload,
	)
	var i AgentStat
	err := row.Scan(
		&i.ID,
		&i.CreatedAt,
		&i.UserID,
		&i.AgentID,
		&i.WorkspaceID,
		&i.TemplateID,
		&i.Payload,
	)
	return i, err
}
const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec
DELETE
FROM

View File

@ -0,0 +1,28 @@
-- name: InsertAgentStat :one
INSERT INTO
	agent_stats (
		id,
		created_at,
		user_id,
		workspace_id,
		template_id,
		agent_id,
		payload
	)
VALUES
	($1, $2, $3, $4, $5, $6, $7) RETURNING *;

-- name: GetTemplateDAUs :many
select
	(created_at at TIME ZONE 'UTC')::date as date,
	count(distinct(user_id)) as amount
from
	agent_stats
where template_id = $1
group by
	date
order by
	date asc;

-- name: DeleteOldAgentStats :exec
DELETE FROM agent_stats WHERE created_at < now() - interval '30 days';

View File

@ -0,0 +1,172 @@
package metricscache
import (
"context"
"sync/atomic"
"time"
"golang.org/x/xerrors"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/retry"
)
// Cache holds the template DAU cache.
// The aggregation queries responsible for these values can take up to a minute
// on large deployments. Even in small deployments, aggregation queries can
// take a few hundred milliseconds, which would ruin page load times and
// database performance if in the hot path.
type Cache struct {
	database database.Store
	log      slog.Logger

	// templateDAUResponses maps template ID (stringified UUID) to its
	// precomputed DAU response. It is stored behind an atomic pointer so
	// readers never block on the refresh goroutine; nil until the first
	// refresh completes.
	templateDAUResponses atomic.Pointer[map[string]codersdk.TemplateDAUsResponse]

	// doneCh is closed when the background refresh goroutine exits.
	doneCh chan struct{}
	// cancel stops the background refresh goroutine's context.
	cancel func()

	// interval is the period between cache refreshes.
	interval time.Duration
}
// New builds a Cache and starts its background refresh loop, which runs
// every interval. A non-positive interval falls back to one hour. Call
// Close to stop the loop.
func New(db database.Store, log slog.Logger, interval time.Duration) *Cache {
	const defaultInterval = time.Hour
	if interval <= 0 {
		interval = defaultInterval
	}

	ctx, cancel := context.WithCancel(context.Background())
	cache := &Cache{
		database: db,
		log:      log,
		interval: interval,
		cancel:   cancel,
		doneCh:   make(chan struct{}),
	}
	go cache.run(ctx)
	return cache
}
// fillEmptyDays inserts zero-amount entries for any calendar days missing
// between consecutive rows, so consumers see a contiguous daily series.
// Rows are expected sorted ascending by date, at most one row per day.
//
// The original inner loop carried a dead `if diff <= day { break }` guard
// (already implied by the loop condition) and a redundant trailing
// `continue`; both are removed here with identical behavior.
func fillEmptyDays(rows []database.GetTemplateDAUsRow) []database.GetTemplateDAUsRow {
	const day = 24 * time.Hour

	var newRows []database.GetTemplateDAUsRow
	for i, row := range rows {
		if i > 0 {
			// Walk forward one day at a time from the previous row,
			// appending zero-amount filler until we reach this row's date.
			gap := rows[i-1]
			for diff := row.Date.Sub(gap.Date); diff > day; diff -= day {
				gap.Date = gap.Date.Add(day)
				gap.Amount = 0
				newRows = append(newRows, gap)
			}
		}
		newRows = append(newRows, row)
	}
	return newRows
}
// refresh prunes stale agent stats, recomputes the DAU response for every
// template, and atomically publishes the new map for readers.
func (c *Cache) refresh(ctx context.Context) error {
	if err := c.database.DeleteOldAgentStats(ctx); err != nil {
		return xerrors.Errorf("delete old stats: %w", err)
	}

	templates, err := c.database.GetTemplates(ctx)
	if err != nil {
		return err
	}

	fresh := make(map[string]codersdk.TemplateDAUsResponse, len(templates))
	for _, tpl := range templates {
		rows, err := c.database.GetTemplateDAUs(ctx, tpl.ID)
		if err != nil {
			return err
		}

		var resp codersdk.TemplateDAUsResponse
		for _, row := range fillEmptyDays(rows) {
			resp.Entries = append(resp.Entries, codersdk.DAUEntry{
				Date:   row.Date,
				Amount: int(row.Amount),
			})
		}
		fresh[tpl.ID.String()] = resp
	}

	c.templateDAUResponses.Store(&fresh)
	return nil
}
// run is the background refresh loop. It refreshes immediately on start,
// retrying with backoff (100ms up to 1m) until a refresh succeeds or ctx is
// canceled, then repeats once per interval tick. doneCh is closed on exit so
// Close can wait for the goroutine.
func (c *Cache) run(ctx context.Context) {
	defer close(c.doneCh)

	ticker := time.NewTicker(c.interval)
	defer ticker.Stop()
	for {
		// Retry until the refresh succeeds; r.Wait returns false once ctx
		// is canceled, ending the retry loop.
		for r := retry.New(time.Millisecond*100, time.Minute); r.Wait(ctx); {
			start := time.Now()
			err := c.refresh(ctx)
			if err != nil {
				// Distinguish cancellation from a genuine refresh failure.
				if ctx.Err() != nil {
					return
				}
				c.log.Error(ctx, "refresh", slog.Error(err))
				continue
			}
			c.log.Debug(
				ctx,
				"metrics refreshed",
				slog.F("took", time.Since(start)),
				slog.F("interval", c.interval),
			)
			break
		}

		select {
		case <-ticker.C:
		case <-c.doneCh:
			// NOTE(review): doneCh is only closed by this goroutine's own
			// defer, so this case appears unreachable — confirm before
			// removing.
			return
		case <-ctx.Done():
			return
		}
	}
}
// Close cancels the refresh loop's context and blocks until the background
// goroutine has exited (signaled by doneCh closing).
func (c *Cache) Close() error {
	c.cancel()
	<-c.doneCh
	return nil
}
// TemplateDAUs returns an empty response if the template doesn't have users
// or is loading for the first time.
func (c *Cache) TemplateDAUs(id uuid.UUID) codersdk.TemplateDAUsResponse {
	cached := c.templateDAUResponses.Load()
	if cached == nil {
		// The first refresh hasn't published anything yet.
		return codersdk.TemplateDAUsResponse{}
	}
	if resp, ok := (*cached)[id.String()]; ok {
		return resp
	}
	// No stats recorded for this template.
	return codersdk.TemplateDAUsResponse{}
}

View File

@ -0,0 +1,185 @@
package metricscache_test
import (
"context"
"reflect"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/metricscache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
func date(year, month, day int) time.Time {
return time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
}
// TestCache exercises the metricscache DAU aggregation against the fake
// database: empty data, a contiguous run of days, and gap filling.
func TestCache(t *testing.T) {
	t.Parallel()

	var (
		zebra = uuid.New()
		tiger = uuid.New()
	)

	type args struct {
		rows []database.InsertAgentStatParams
	}
	tests := []struct {
		name string
		args args
		want []codersdk.DAUEntry
	}{
		{"empty", args{}, nil},
		{"one hole", args{
			rows: []database.InsertAgentStatParams{
				{
					CreatedAt: date(2022, 8, 27),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 8, 30),
					UserID:    zebra,
				},
			},
		}, []codersdk.DAUEntry{
			{
				Date:   date(2022, 8, 27),
				Amount: 1,
			},
			{
				Date:   date(2022, 8, 28),
				Amount: 0,
			},
			{
				Date:   date(2022, 8, 29),
				Amount: 0,
			},
			{
				Date:   date(2022, 8, 30),
				Amount: 1,
			},
		}},
		{"no holes", args{
			rows: []database.InsertAgentStatParams{
				{
					CreatedAt: date(2022, 8, 27),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 8, 28),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 8, 29),
					UserID:    zebra,
				},
			},
		}, []codersdk.DAUEntry{
			{
				Date:   date(2022, 8, 27),
				Amount: 1,
			},
			{
				Date:   date(2022, 8, 28),
				Amount: 1,
			},
			{
				Date:   date(2022, 8, 29),
				Amount: 1,
			},
		}},
		{"holes", args{
			rows: []database.InsertAgentStatParams{
				{
					CreatedAt: date(2022, 1, 1),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 1, 1),
					UserID:    tiger,
				},
				{
					CreatedAt: date(2022, 1, 4),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 1, 7),
					UserID:    zebra,
				},
				{
					CreatedAt: date(2022, 1, 7),
					UserID:    tiger,
				},
			},
		}, []codersdk.DAUEntry{
			{
				Date:   date(2022, 1, 1),
				Amount: 2,
			},
			{
				Date:   date(2022, 1, 2),
				Amount: 0,
			},
			{
				Date:   date(2022, 1, 3),
				Amount: 0,
			},
			{
				Date:   date(2022, 1, 4),
				Amount: 1,
			},
			{
				Date:   date(2022, 1, 5),
				Amount: 0,
			},
			{
				Date:   date(2022, 1, 6),
				Amount: 0,
			},
			{
				Date:   date(2022, 1, 7),
				Amount: 2,
			},
		}},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var (
				db    = databasefake.New()
				cache = metricscache.New(db, slogtest.Make(t, nil), time.Millisecond*100)
			)
			defer cache.Close()

			templateID := uuid.New()
			// Check insert errors explicitly: a silently failed insert would
			// otherwise surface only as a confusing Eventually timeout below.
			_, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{
				ID: templateID,
			})
			require.NoError(t, err)
			for _, row := range tt.args.rows {
				row.TemplateID = templateID
				_, err := db.InsertAgentStat(context.Background(), row)
				require.NoError(t, err)
			}

			var got codersdk.TemplateDAUsResponse
			require.Eventuallyf(t, func() bool {
				got = cache.TemplateDAUs(templateID)
				return reflect.DeepEqual(got.Entries, tt.want)
			}, testutil.WaitShort, testutil.IntervalFast,
				"GetDAUs() = %v, want %v", got, tt.want,
			)
		})
	}
}

View File

@ -39,6 +39,8 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
Pubsub: fPubsub,
}
api := New(&opts)
defer api.Close()
server := httptest.NewServer(api.Handler)
defer server.Close()
userID := uuid.New()

View File

@ -517,6 +517,20 @@ func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusOK, convertTemplate(updated, count, createdByNameMap[updated.ID.String()]))
}
// templateDAUs serves the cached daily-active-user metrics for a template.
func (api *API) templateDAUs(rw http.ResponseWriter, r *http.Request) {
	tpl := httpmw.TemplateParam(r)
	if !api.Authorize(r, rbac.ActionRead, tpl) {
		httpapi.ResourceNotFound(rw)
		return
	}

	daus := api.metricsCache.TemplateDAUs(tpl.ID)
	if daus.Entries == nil {
		// Serialize as an empty JSON array rather than null.
		daus.Entries = []codersdk.DAUEntry{}
	}
	httpapi.Write(rw, http.StatusOK, daus)
}
type autoImportTemplateOpts struct {
name string
archive []byte

View File

@ -10,10 +10,16 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
@ -539,3 +545,100 @@ func TestDeleteTemplate(t *testing.T) {
require.Equal(t, http.StatusPreconditionFailed, apiErr.StatusCode())
})
}
// TestTemplateDAUs provisions a workspace with an agent, opens an SSH
// session through it, and asserts the DAU endpoint eventually reports one
// active user for today.
func TestTemplateDAUs(t *testing.T) {
	t.Parallel()

	client := coderdtest.New(t, &coderdtest.Options{
		IncludeProvisionerD: true,
	})

	user := coderdtest.CreateFirstUser(t, client)
	authToken := uuid.NewString()
	version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
		Parse:           echo.ParseComplete,
		ProvisionDryRun: echo.ProvisionComplete,
		Provision: []*proto.Provision_Response{{
			Type: &proto.Provision_Response_Complete{
				Complete: &proto.Provision_Complete{
					Resources: []*proto.Resource{{
						Name: "example",
						Type: "aws_instance",
						Agents: []*proto.Agent{{
							Id: uuid.NewString(),
							Auth: &proto.Agent_Token{
								Token: authToken,
							},
						}},
					}},
				},
			},
		}},
	})

	template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
	coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
	workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
	coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)

	agentClient := codersdk.New(client.URL)
	agentClient.SessionToken = authToken
	agentCloser := agent.New(agent.Options{
		Logger:            slogtest.Make(t, nil),
		StatsReporter:     agentClient.AgentReportStats,
		WebRTCDialer:      agentClient.ListenWorkspaceAgent,
		FetchMetadata:     agentClient.WorkspaceAgentMetadata,
		CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
	})
	defer func() {
		_ = agentCloser.Close()
	}()
	resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	opts := &peer.ConnOptions{
		Logger: slogtest.Make(t, nil).Named("client"),
	}

	// Use the shared deadline context so this request cannot outlive the
	// test (previously passed context.Background()).
	daus, err := client.TemplateDAUs(ctx, template.ID)
	require.NoError(t, err)
	require.Equal(t, &codersdk.TemplateDAUsResponse{
		Entries: []codersdk.DAUEntry{},
	}, daus, "no DAUs when stats are empty")

	conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts)
	require.NoError(t, err)
	defer func() {
		_ = conn.Close()
	}()

	// Close the SSH client and session so the test doesn't leak connections.
	sshConn, err := conn.SSHClient()
	require.NoError(t, err)
	defer func() {
		_ = sshConn.Close()
	}()

	session, err := sshConn.NewSession()
	require.NoError(t, err)
	defer func() {
		_ = session.Close()
	}()

	_, err = session.Output("echo hello")
	require.NoError(t, err)

	want := &codersdk.TemplateDAUsResponse{
		Entries: []codersdk.DAUEntry{
			{
				Date:   time.Now().UTC().Truncate(time.Hour * 24),
				Amount: 1,
			},
		},
	}
	require.Eventuallyf(t, func() bool {
		daus, err = client.TemplateDAUs(ctx, template.ID)
		require.NoError(t, err)
		return assert.ObjectsAreEqual(want, daus)
	},
		testutil.WaitShort, testutil.IntervalFast,
		"got %+v != %+v", daus, want,
	)
}

View File

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/netip"
"reflect"
"strconv"
"strings"
"time"
@ -18,6 +19,7 @@ import (
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"tailscale.com/tailcfg"
"cdr.dev/slog"
@ -745,6 +747,130 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordi
return workspaceAgent, nil
}
// workspaceAgentReportStats is an agent-authenticated websocket endpoint.
// It periodically asks the connected agent for a stats report and persists
// changed reports to agent_stats for DAU aggregation.
func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) {
	api.websocketWaitMutex.Lock()
	api.websocketWaitGroup.Add(1)
	api.websocketWaitMutex.Unlock()
	defer api.websocketWaitGroup.Done()

	workspaceAgent := httpmw.WorkspaceAgent(r)
	resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
	if err != nil {
		httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to get workspace resource.",
			Detail:  err.Error(),
		})
		return
	}

	build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
	if err != nil {
		httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to get build.",
			Detail:  err.Error(),
		})
		return
	}

	workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID)
	if err != nil {
		httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to get workspace.",
			Detail:  err.Error(),
		})
		return
	}

	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
		CompressionMode: websocket.CompressionDisabled,
	})
	if err != nil {
		httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to accept websocket.",
			Detail:  err.Error(),
		})
		return
	}
	defer conn.Close(websocket.StatusAbnormalClosure, "")

	// From here on the connection has been hijacked by the websocket
	// upgrade, so errors must be logged and signaled via the websocket close
	// frame; a plain HTTP response written to rw cannot reach the client.
	ctx := r.Context()
	// Allow overriding the stat interval for debugging and testing purposes.
	timer := time.NewTicker(api.AgentStatsRefreshInterval)
	// Stop the ticker so it doesn't leak when this handler returns.
	defer timer.Stop()

	var lastReport codersdk.AgentStatsReportResponse
	for {
		err := wsjson.Write(ctx, conn, codersdk.AgentStatsReportRequest{})
		if err != nil {
			api.Logger.Debug(ctx, "write stats report request", slog.Error(err))
			return
		}

		var rep codersdk.AgentStatsReportResponse
		err = wsjson.Read(ctx, conn, &rep)
		if err != nil {
			api.Logger.Debug(ctx, "read stats report response", slog.Error(err))
			return
		}

		repJSON, err := json.Marshal(rep)
		if err != nil {
			api.Logger.Error(ctx, "marshal stats report payload", slog.Error(err))
			return
		}

		// Avoid inserting duplicate rows to preserve DB space.
		// We will see duplicate reports when on idle connections
		// (e.g. web terminal left open) or when there are no connections at
		// all.
		var insert = !reflect.DeepEqual(lastReport, rep)

		api.Logger.Debug(ctx, "read stats report",
			slog.F("interval", api.AgentStatsRefreshInterval),
			slog.F("agent", workspaceAgent.ID),
			slog.F("resource", resource.ID),
			slog.F("workspace", workspace.ID),
			slog.F("insert", insert),
			slog.F("payload", rep),
		)

		if insert {
			lastReport = rep
			_, err = api.Database.InsertAgentStat(ctx, database.InsertAgentStatParams{
				ID:          uuid.New(),
				CreatedAt:   time.Now(),
				AgentID:     workspaceAgent.ID,
				WorkspaceID: build.WorkspaceID,
				UserID:      workspace.OwnerID,
				TemplateID:  workspace.TemplateID,
				Payload:     json.RawMessage(repJSON),
			})
			if err != nil {
				api.Logger.Error(ctx, "insert agent stat", slog.Error(err))
				return
			}
		}

		select {
		case <-timer.C:
			continue
		case <-ctx.Done():
			conn.Close(websocket.StatusNormalClosure, "")
			return
		}
	}
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.