feat(coderd): update API to allow filtering provisioner daemons by tags (#15448)

This PR provides new parameters to an endpoint that will be necessary
for #15048
This commit is contained in:
Sas Swart
2024-11-15 11:33:22 +02:00
committed by GitHub
parent 40802958e9
commit 814dd6f854
27 changed files with 389 additions and 80 deletions

View File

@ -1890,7 +1890,7 @@ func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.Provisi
return fetchWithPostFilter(q.auth, policy.ActionRead, fetch)(ctx, nil)
}
func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetProvisionerDaemonsByOrganization)(ctx, organizationID)
}

View File

@ -2066,9 +2066,9 @@ func (s *MethodTestSuite) TestExtraMethods() {
}),
})
s.NoError(err, "insert provisioner daemon")
ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), org.ID)
ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), database.GetProvisionerDaemonsByOrganizationParams{OrganizationID: org.ID})
s.NoError(err, "get provisioner daemon by org")
check.Args(org.ID).Asserts(d, policy.ActionRead).Returns(ds)
check.Args(database.GetProvisionerDaemonsByOrganizationParams{OrganizationID: org.ID}).Asserts(d, policy.ActionRead).Returns(ds)
}))
s.Run("DeleteOldProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) {
_, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
@ -2560,7 +2560,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
j := dbgen.ProvisionerJob(s.T(), db, nil, database.ProvisionerJob{
StartedAt: sql.NullTime{Valid: false},
})
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, Tags: must(json.Marshal(j.Tags))}).
check.Args(database.AcquireProvisionerJobParams{OrganizationID: j.OrganizationID, Types: []database.ProvisionerType{j.Provisioner}, ProvisionerTags: must(json.Marshal(j.Tags))}).
Asserts( /*rbac.ResourceSystem, policy.ActionUpdate*/ )
}))
s.Run("UpdateProvisionerJobWithCompleteByID", s.Subtest(func(db database.Store, check *expects) {

View File

@ -194,8 +194,8 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
UUID: uuid.New(),
Valid: true,
},
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: []byte(`{"scope": "organization"}`),
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: []byte(`{"scope": "organization"}`),
})
require.NoError(b.t, err, "acquire starting job")
if j.ID == job.ID {

View File

@ -531,11 +531,11 @@ func ProvisionerJob(t testing.TB, db database.Store, ps pubsub.Pubsub, orig data
}
if !orig.StartedAt.Time.IsZero() {
job, err = db.AcquireProvisionerJob(genCtx, database.AcquireProvisionerJobParams{
StartedAt: orig.StartedAt,
OrganizationID: job.OrganizationID,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
Tags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
StartedAt: orig.StartedAt,
OrganizationID: job.OrganizationID,
Types: []database.ProvisionerType{database.ProvisionerTypeEcho},
ProvisionerTags: must(json.Marshal(orig.Tags)),
WorkerID: uuid.NullUUID{},
})
require.NoError(t, err)
// There is no easy way to make sure we acquire the correct job.

View File

@ -1194,8 +1194,8 @@ func (q *FakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
continue
}
tags := map[string]string{}
if arg.Tags != nil {
err := json.Unmarshal(arg.Tags, &tags)
if arg.ProvisionerTags != nil {
err := json.Unmarshal(arg.ProvisionerTags, &tags)
if err != nil {
return provisionerJob, xerrors.Errorf("unmarshal: %w", err)
}
@ -3625,16 +3625,28 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi
return out, nil
}
func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
daemons := make([]database.ProvisionerDaemon, 0)
for _, daemon := range q.provisionerDaemons {
if daemon.OrganizationID == organizationID {
daemon.Tags = maps.Clone(daemon.Tags)
daemons = append(daemons, daemon)
if daemon.OrganizationID != arg.OrganizationID {
continue
}
// Special case for untagged provisioners: only match untagged jobs.
// Ref: coderd/database/queries/provisionerjobs.sql:24-30
// CASE WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
// THEN nested.tags :: jsonb = @tags :: jsonb
if tagsEqual(arg.WantTags, tagsUntagged) && !tagsEqual(arg.WantTags, daemon.Tags) {
continue
}
// ELSE nested.tags :: jsonb <@ @tags :: jsonb
if !tagsSubset(arg.WantTags, daemon.Tags) {
continue
}
daemon.Tags = maps.Clone(daemon.Tags)
daemons = append(daemons, daemon)
}
return daemons, nil

View File

@ -959,9 +959,9 @@ func (m queryMetricsStore) GetProvisionerDaemons(ctx context.Context) ([]databas
return daemons, err
}
func (m queryMetricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
func (m queryMetricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, arg database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, organizationID)
r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, arg)
m.queryLatencies.WithLabelValues("GetProvisionerDaemonsByOrganization").Observe(time.Since(start).Seconds())
return r0, r1
}

View File

@ -1973,7 +1973,7 @@ func (mr *MockStoreMockRecorder) GetProvisionerDaemons(arg0 any) *gomock.Call {
}
// GetProvisionerDaemonsByOrganization mocks base method.
func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 uuid.UUID) ([]database.ProvisionerDaemon, error) {
func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 database.GetProvisionerDaemonsByOrganizationParams) ([]database.ProvisionerDaemon, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProvisionerDaemonsByOrganization", arg0, arg1)
ret0, _ := ret[0].([]database.ProvisionerDaemon)

View File

@ -198,6 +198,10 @@ CREATE TYPE startup_script_behavior AS ENUM (
'non-blocking'
);
CREATE DOMAIN tagset AS jsonb;
COMMENT ON DOMAIN tagset IS 'A set of tags that match provisioner daemons to provisioner jobs, which can originate from workspaces or templates. tagset is a narrowed type over jsonb. It is expected to be the JSON representation of map[string]string. That is, {"key1": "value1", "key2": "value2"}. We need the narrowed type instead of just using jsonb so that we can give sqlc a type hint, otherwise it defaults to json.RawMessage. json.RawMessage is a suboptimal type to use in the context that we need tagset for.';
CREATE TYPE tailnet_status AS ENUM (
'ok',
'lost'
@ -376,6 +380,21 @@ BEGIN
END;
$$;
CREATE FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset) RETURNS boolean
LANGUAGE plpgsql
AS $$
BEGIN
RETURN CASE
-- Special case for untagged provisioners, where only an exact match should count
WHEN job_tags::jsonb = '{"scope": "organization", "owner": ""}'::jsonb THEN job_tags::jsonb = provisioner_tags::jsonb
-- General case
ELSE job_tags::jsonb <@ provisioner_tags::jsonb
END;
END;
$$;
COMMENT ON FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset) IS 'Returns true if the provisioner_tags contains the job_tags, or if the job_tags represents an untagged provisioner and the superset is exactly equal to the subset.';
CREATE FUNCTION remove_organization_member_role() RETURNS trigger
LANGUAGE plpgsql
AS $$

View File

@ -0,0 +1,3 @@
DROP FUNCTION IF EXISTS provisioner_tagset_contains(tagset, tagset);
DROP DOMAIN IF EXISTS tagset;

View File

@ -0,0 +1,17 @@
CREATE DOMAIN tagset AS jsonb;
COMMENT ON DOMAIN tagset IS 'A set of tags that match provisioner daemons to provisioner jobs, which can originate from workspaces or templates. tagset is a narrowed type over jsonb. It is expected to be the JSON representation of map[string]string. That is, {"key1": "value1", "key2": "value2"}. We need the narrowed type instead of just using jsonb so that we can give sqlc a type hint, otherwise it defaults to json.RawMessage. json.RawMessage is a suboptimal type to use in the context that we need tagset for.';
CREATE OR REPLACE FUNCTION provisioner_tagset_contains(provisioner_tags tagset, job_tags tagset)
RETURNS boolean AS $$
BEGIN
RETURN CASE
-- Special case for untagged provisioners, where only an exact match should count
WHEN job_tags::jsonb = '{"scope": "organization", "owner": ""}'::jsonb THEN job_tags::jsonb = provisioner_tags::jsonb
-- General case
ELSE job_tags::jsonb <@ provisioner_tags::jsonb
END;
END;
$$ LANGUAGE plpgsql;
COMMENT ON FUNCTION provisioner_tagset_contains(tagset, tagset) IS 'Returns true if the provisioner_tags contains the job_tags, or if the job_tags represents an untagged provisioner and the superset is exactly equal to the subset.';

View File

@ -196,7 +196,7 @@ type sqlcQuerier interface {
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetPreviousTemplateVersion(ctx context.Context, arg GetPreviousTemplateVersionParams) (TemplateVersion, error)
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error)
GetProvisionerDaemonsByOrganization(ctx context.Context, arg GetProvisionerDaemonsByOrganizationParams) ([]ProvisionerDaemon, error)
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uuid.UUID) ([]ProvisionerJobTiming, error)
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)

View File

@ -1020,7 +1020,7 @@ func TestQueuePosition(t *testing.T) {
UUID: uuid.New(),
Valid: true,
},
Tags: json.RawMessage("{}"),
ProvisionerTags: json.RawMessage("{}"),
})
require.NoError(t, err)
require.Equal(t, jobs[0].ID, job.ID)

View File

@ -5269,11 +5269,20 @@ SELECT
FROM
provisioner_daemons
WHERE
organization_id = $1
-- This is the original search criteria:
organization_id = $1 :: uuid
AND
-- adding support for searching by tags:
($2 :: tagset = 'null' :: tagset OR provisioner_tagset_contains(provisioner_daemons.tags::tagset, $2::tagset))
`
func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) {
rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, organizationID)
type GetProvisionerDaemonsByOrganizationParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
WantTags StringMap `db:"want_tags" json:"want_tags"`
}
func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, arg GetProvisionerDaemonsByOrganizationParams) ([]ProvisionerDaemon, error) {
rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, arg.OrganizationID, arg.WantTags)
if err != nil {
return nil, err
}
@ -5523,21 +5532,17 @@ WHERE
SELECT
id
FROM
provisioner_jobs AS nested
provisioner_jobs AS potential_job
WHERE
nested.started_at IS NULL
AND nested.organization_id = $3
potential_job.started_at IS NULL
AND potential_job.organization_id = $3
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY($4 :: provisioner_type [ ])
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
THEN nested.tags :: jsonb = $5 :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ $5 :: jsonb
END
AND potential_job.provisioner = ANY($4 :: provisioner_type [ ])
-- elsewhere, we use the tagset type, but here we use jsonb for backward compatibility
-- they are aliases and the code that calls this query already relies on a different type
AND provisioner_tagset_contains($5 :: jsonb, potential_job.tags :: jsonb)
ORDER BY
nested.created_at
potential_job.created_at
FOR UPDATE
SKIP LOCKED
LIMIT
@ -5546,11 +5551,11 @@ WHERE
`
type AcquireProvisionerJobParams struct {
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Types []ProvisionerType `db:"types" json:"types"`
Tags json.RawMessage `db:"tags" json:"tags"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
Types []ProvisionerType `db:"types" json:"types"`
ProvisionerTags json.RawMessage `db:"provisioner_tags" json:"provisioner_tags"`
}
// Acquires the lock for a single job that isn't started, completed,
@ -5565,7 +5570,7 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
arg.WorkerID,
arg.OrganizationID,
pq.Array(arg.Types),
arg.Tags,
arg.ProvisionerTags,
)
var i ProvisionerJob
err := row.Scan(

View File

@ -10,7 +10,11 @@ SELECT
FROM
provisioner_daemons
WHERE
organization_id = @organization_id;
-- This is the original search criteria:
organization_id = @organization_id :: uuid
AND
-- adding support for searching by tags:
(@want_tags :: tagset = 'null' :: tagset OR provisioner_tagset_contains(provisioner_daemons.tags::tagset, @want_tags::tagset));
-- name: DeleteOldProvisionerDaemons :exec
-- Delete provisioner daemons that have been created at least a week ago

View File

@ -16,21 +16,17 @@ WHERE
SELECT
id
FROM
provisioner_jobs AS nested
provisioner_jobs AS potential_job
WHERE
nested.started_at IS NULL
AND nested.organization_id = @organization_id
potential_job.started_at IS NULL
AND potential_job.organization_id = @organization_id
-- Ensure the caller has the correct provisioner.
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
AND CASE
-- Special case for untagged provisioners: only match untagged jobs.
WHEN nested.tags :: jsonb = '{"scope": "organization", "owner": ""}' :: jsonb
THEN nested.tags :: jsonb = @tags :: jsonb
-- Ensure the caller satisfies all job tags.
ELSE nested.tags :: jsonb <@ @tags :: jsonb
END
AND potential_job.provisioner = ANY(@types :: provisioner_type [ ])
-- elsewhere, we use the tagset type, but here we use jsonb for backward compatibility
-- they are aliases and the code that calls this query already relies on a different type
AND provisioner_tagset_contains(@provisioner_tags :: jsonb, potential_job.tags :: jsonb)
ORDER BY
nested.created_at
potential_job.created_at
FOR UPDATE
SKIP LOCKED
LIMIT
@ -160,4 +156,4 @@ RETURNING *;
-- name: GetProvisionerJobTimingsByJobID :many
SELECT * FROM provisioner_job_timings
WHERE job_id = $1
ORDER BY started_at ASC;
ORDER BY started_at ASC;

View File

@ -35,6 +35,9 @@ sql:
- db_type: "name_organization_pair"
go_type:
type: "NameOrganizationPair"
- db_type: "tagset"
go_type:
type: "StringMap"
- column: "custom_roles.site_permissions"
go_type:
type: "CustomRolePermissions"