feat: Add external provisioner daemons (#4935)

* Start to port over provisioner daemons PR

* Move to Enterprise

* Begin adding tests for external registration

* Move provisioner daemons query to enterprise

* Move around provisioner daemons schema

* Add tags to provisioner daemons

* make gen

* Add user local provisioner daemons

* Add provisioner daemons

* Add feature for external daemons

* Add command to start a provisioner daemon

* Add provisioner tags to template push and create

* Rename migration files

* Fix tests

* Fix entitlements test

* PR comments

* Update migration

* Fix FE types
This commit is contained in:
Kyle Carberry
2022-11-16 16:34:06 -06:00
committed by GitHub
parent 66d20cabac
commit b6703b11c6
51 changed files with 1095 additions and 372 deletions

View File

@ -278,6 +278,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa
Type: database.ProvisionerJobTypeWorkspaceBuild,
StorageMethod: priorJob.StorageMethod,
FileID: priorJob.FileID,
Tags: priorJob.Tags,
Input: input,
})
if err != nil {

View File

@ -1,8 +1,10 @@
package coderd
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
@ -18,10 +20,13 @@ import (
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/tailcfg"
@ -32,17 +37,20 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/metricscache"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
)
@ -323,13 +331,6 @@ func New(options *Options) *API {
r.Get("/{fileID}", api.fileByID)
r.Post("/", api.postFile)
})
r.Route("/provisionerdaemons", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
)
r.Get("/", api.provisionerDaemons)
})
r.Route("/organizations", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@ -595,18 +596,20 @@ type API struct {
// RootHandler serves "/"
RootHandler chi.Router
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
metricsCache *metricscache.Cache
siteHandler http.Handler
WebsocketWaitMutex sync.Mutex
WebsocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
}
// Close waits for all WebSocket connections to drain before returning.
func (api *API) Close() error {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Wait()
api.websocketWaitMutex.Unlock()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Wait()
api.WebsocketWaitMutex.Unlock()
api.metricsCache.Close()
coordinator := api.TailnetCoordinator.Load()
@ -635,3 +638,70 @@ func compressHandler(h http.Handler) http.Handler {
return cmp.Handler(h)
}
// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
// in the same process.
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
clientSession, serverSession := provisionersdk.TransportPipe()
defer func() {
if err != nil {
_ = clientSession.Close()
_ = serverSession.Close()
}
}()
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
Tags: dbtype.StringMap{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
})
if err != nil {
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
}
tags, err := json.Marshal(daemon.Tags)
if err != nil {
return nil, xerrors.Errorf("marshal tags: %w", err)
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Tags: tags,
QuotaCommitter: &api.QuotaCommitter,
AcquireJobDebounce: debounce,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
})
if err != nil {
return nil, err
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
go func() {
err := server.Serve(ctx, serverSession)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
}
// close the sessions so we don't leak goroutines serving them.
_ = clientSession.Close()
_ = serverSession.Close()
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
}

View File

@ -19,7 +19,6 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
@ -204,11 +203,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID),
},
"GET:/api/v2/provisionerdaemons": {
StatusCode: http.StatusOK,
AssertObject: rbac.ResourceProvisionerDaemon,
},
"POST:/api/v2/parameters/{scope}/{id}": {
AssertAction: rbac.ActionUpdate,
AssertObject: rbac.ResourceTemplate,
@ -303,16 +297,6 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
if !ok {
t.Fail()
}
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")

View File

@ -69,7 +69,7 @@ import (
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionerd/proto"
provisionerdproto "github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
@ -328,8 +328,43 @@ func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer {
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return coderAPI.ListenProvisionerDaemon(ctx, 0)
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, 0)
}, &provisionerd.Options{
Filesystem: fs,
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),
PollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)),
},
WorkDirectory: t.TempDir(),
})
t.Cleanup(func() {
_ = closer.Close()
})
return closer
}
func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uuid.UUID, tags map[string]string) io.Closer {
echoClient, echoServer := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
_ = echoClient.Close()
_ = echoServer.Close()
cancelFunc()
})
fs := afero.NewMemMapFs()
go func() {
err := echo.Serve(ctx, fs, &provisionersdk.ServeOptions{
Listener: echoServer,
})
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return client.ServeProvisionerDaemon(ctx, org, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, tags)
}, &provisionerd.Options{
Filesystem: fs,
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),

View File

@ -3,6 +3,7 @@ package databasefake
import (
"context"
"database/sql"
"encoding/json"
"sort"
"strings"
"sync"
@ -146,6 +147,29 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
if !found {
continue
}
tags := map[string]string{}
if arg.Tags != nil {
err := json.Unmarshal(arg.Tags, &tags)
if err != nil {
return provisionerJob, xerrors.Errorf("unmarshal: %w", err)
}
}
missing := false
for key, value := range provisionerJob.Tags {
provided, found := tags[key]
if !found {
missing = true
break
}
if provided != value {
missing = true
break
}
}
if missing {
continue
}
provisionerJob.StartedAt = arg.StartedAt
provisionerJob.UpdatedAt = arg.StartedAt.Time
provisionerJob.WorkerID = arg.WorkerID
@ -2244,6 +2268,7 @@ func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.In
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: arg.Tags,
}
q.provisionerDaemons = append(q.provisionerDaemons, daemon)
return daemon, nil
@ -2264,6 +2289,7 @@ func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser
FileID: arg.FileID,
Type: arg.Type,
Input: arg.Input,
Tags: arg.Tags,
}
q.provisionerJobs = append(q.provisionerJobs, job)
return job, nil

View File

@ -0,0 +1,30 @@
package dbtype
import (
"database/sql/driver"
"encoding/json"
"golang.org/x/xerrors"
)
type StringMap map[string]string
func (m *StringMap) Scan(src interface{}) error {
if src == nil {
return nil
}
switch src := src.(type) {
case []byte:
err := json.Unmarshal(src, m)
if err != nil {
return err
}
default:
return xerrors.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, m)
}
return nil
}
func (m StringMap) Value() (driver.Value, error) {
return json.Marshal(m)
}

View File

@ -269,7 +269,8 @@ CREATE TABLE provisioner_daemons (
updated_at timestamp with time zone,
name character varying(64) NOT NULL,
provisioners provisioner_type[] NOT NULL,
replica_id uuid
replica_id uuid,
tags jsonb DEFAULT '{}'::jsonb NOT NULL
);
CREATE TABLE provisioner_job_logs (
@ -306,7 +307,8 @@ CREATE TABLE provisioner_jobs (
type provisioner_job_type NOT NULL,
input jsonb NOT NULL,
worker_id uuid,
file_id uuid NOT NULL
file_id uuid NOT NULL,
tags jsonb DEFAULT '{"scope": "organization"}'::jsonb NOT NULL
);
CREATE TABLE replicas (

View File

@ -0,0 +1,2 @@
ALTER TABLE provisioner_daemons DROP COLUMN tags;
ALTER TABLE provisioner_jobs DROP COLUMN tags;

View File

@ -0,0 +1,5 @@
ALTER TABLE provisioner_daemons ADD COLUMN tags jsonb NOT NULL DEFAULT '{}';
-- We must add the organization scope by default, otherwise pending jobs
-- could be provisioned on new daemons that don't match the tags.
ALTER TABLE provisioner_jobs ADD COLUMN tags jsonb NOT NULL DEFAULT '{"scope":"organization"}';

View File

@ -10,6 +10,7 @@ import (
"fmt"
"time"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@ -525,6 +526,7 @@ type ProvisionerDaemon struct {
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
type ProvisionerJob struct {
@ -543,6 +545,7 @@ type ProvisionerJob struct {
Input json.RawMessage `db:"input" json:"input"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
FileID uuid.UUID `db:"file_id" json:"file_id"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
type ProvisionerJobLog struct {

View File

@ -10,6 +10,7 @@ import (
"encoding/json"
"time"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@ -2243,7 +2244,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
SELECT
id, created_at, updated_at, name, provisioners, replica_id
id, created_at, updated_at, name, provisioners, replica_id, tags
FROM
provisioner_daemons
WHERE
@ -2260,13 +2261,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID)
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
)
return i, err
}
const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many
SELECT
id, created_at, updated_at, name, provisioners, replica_id
id, created_at, updated_at, name, provisioners, replica_id, tags
FROM
provisioner_daemons
`
@ -2287,6 +2289,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
); err != nil {
return nil, err
}
@ -2307,10 +2310,11 @@ INSERT INTO
id,
created_at,
"name",
provisioners
provisioners,
tags
)
VALUES
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id
($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at, name, provisioners, replica_id, tags
`
type InsertProvisionerDaemonParams struct {
@ -2318,6 +2322,7 @@ type InsertProvisionerDaemonParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error) {
@ -2326,6 +2331,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
arg.CreatedAt,
arg.Name,
pq.Array(arg.Provisioners),
arg.Tags,
)
var i ProvisionerDaemon
err := row.Scan(
@ -2335,6 +2341,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
)
return i, err
}
@ -2487,19 +2494,22 @@ WHERE
AND nested.canceled_at IS NULL
AND nested.completed_at IS NULL
AND nested.provisioner = ANY($3 :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ $4 :: jsonb
ORDER BY
nested.created_at
FOR UPDATE
SKIP LOCKED
LIMIT
1
) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
`
type AcquireProvisionerJobParams struct {
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
Types []ProvisionerType `db:"types" json:"types"`
Tags json.RawMessage `db:"tags" json:"tags"`
}
// Acquires the lock for a single job that isn't started, completed,
@ -2509,7 +2519,12 @@ type AcquireProvisionerJobParams struct {
// multiple provisioners from acquiring the same jobs. See:
// https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) {
row := q.db.QueryRowContext(ctx, acquireProvisionerJob, arg.StartedAt, arg.WorkerID, pq.Array(arg.Types))
row := q.db.QueryRowContext(ctx, acquireProvisionerJob,
arg.StartedAt,
arg.WorkerID,
pq.Array(arg.Types),
arg.Tags,
)
var i ProvisionerJob
err := row.Scan(
&i.ID,
@ -2527,13 +2542,14 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}
const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
FROM
provisioner_jobs
WHERE
@ -2559,13 +2575,14 @@ func (q *sqlQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (P
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}
const getProvisionerJobsByIDs = `-- name: GetProvisionerJobsByIDs :many
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
FROM
provisioner_jobs
WHERE
@ -2597,6 +2614,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
); err != nil {
return nil, err
}
@ -2612,7 +2630,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI
}
const getProvisionerJobsCreatedAfter = `-- name: GetProvisionerJobsCreatedAfter :many
SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id FROM provisioner_jobs WHERE created_at > $1
SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags FROM provisioner_jobs WHERE created_at > $1
`
func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) {
@ -2640,6 +2658,7 @@ func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, created
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
); err != nil {
return nil, err
}
@ -2666,10 +2685,11 @@ INSERT INTO
storage_method,
file_id,
"type",
"input"
"input",
tags
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
`
type InsertProvisionerJobParams struct {
@ -2683,6 +2703,7 @@ type InsertProvisionerJobParams struct {
FileID uuid.UUID `db:"file_id" json:"file_id"`
Type ProvisionerJobType `db:"type" json:"type"`
Input json.RawMessage `db:"input" json:"input"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) {
@ -2697,6 +2718,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi
arg.FileID,
arg.Type,
arg.Input,
arg.Tags,
)
var i ProvisionerJob
err := row.Scan(
@ -2715,6 +2737,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}

View File

@ -18,10 +18,11 @@ INSERT INTO
id,
created_at,
"name",
provisioners
provisioners,
tags
)
VALUES
($1, $2, $3, $4) RETURNING *;
($1, $2, $3, $4, $5) RETURNING *;
-- name: UpdateProvisionerDaemonByID :exec
UPDATE

View File

@ -22,6 +22,8 @@ WHERE
AND nested.canceled_at IS NULL
AND nested.completed_at IS NULL
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ @tags :: jsonb
ORDER BY
nested.created_at
FOR UPDATE
@ -61,10 +63,11 @@ INSERT INTO
storage_method,
file_id,
"type",
"input"
"input",
tags
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *;
-- name: UpdateProvisionerJobByID :exec
UPDATE

View File

@ -17,6 +17,10 @@ packages:
output_db_file_name: db_tmp.go
overrides:
- column: "provisioner_daemons.tags"
go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap"
- column: "provisioner_jobs.tags"
go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap"
- column: "users.rbac_roles"
go_type: "github.com/lib/pq.StringArray"
- column: "templates.user_acl"

View File

@ -1,113 +0,0 @@
package coderd
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
)
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
daemons, err := api.Database.GetProvisionerDaemons(ctx)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, daemons)
}
// ListenProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
// in the same process.
func (api *API) ListenProvisionerDaemon(ctx context.Context, acquireJobDebounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
clientSession, serverSession := provisionersdk.TransportPipe()
defer func() {
if err != nil {
_ = clientSession.Close()
_ = serverSession.Close()
}
}()
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
})
if err != nil {
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
AcquireJobDebounce: acquireJobDebounce,
QuotaCommitter: &api.QuotaCommitter,
})
if err != nil {
return nil, err
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
go func() {
err := server.Serve(ctx, serverSession)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
}
// close the sessions so we don't leak goroutines serving them.
_ = clientSession.Close()
_ = serverSession.Close()
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
}

View File

@ -1,76 +0,0 @@
package coderd_test
import (
"context"
"crypto/rand"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
)
func TestProvisionerDaemons(t *testing.T) {
t.Parallel()
t.Run("PayloadTooBig", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// Takes too long to allocate memory on Windows!
t.Skip()
}
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
data := make([]byte, provisionersdk.MaxMessageSize)
rand.Read(data)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resp, err := client.Upload(ctx, codersdk.ContentTypeTar, data)
require.NoError(t, err)
t.Log(resp.ID)
version, err := client.CreateTemplateVersion(ctx, user.OrganizationID, codersdk.CreateTemplateVersionRequest{
StorageMethod: codersdk.ProvisionerStorageMethodFile,
FileID: resp.ID,
Provisioner: codersdk.ProvisionerTypeEcho,
})
require.NoError(t, err)
require.Eventually(t, func() bool {
var err error
version, err = client.TemplateVersion(ctx, version.ID)
return assert.NoError(t, err) && version.Job.Error != ""
}, testutil.WaitShort, testutil.IntervalFast)
})
}
func TestProvisionerDaemonsByOrganization(t *testing.T) {
t.Parallel()
t.Run("NoAuth", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.ProvisionerDaemons(ctx)
require.Error(t, err)
})
t.Run("Get", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err)
})
}

View File

@ -39,6 +39,7 @@ type Server struct {
ID uuid.UUID
Logger slog.Logger
Provisioners []database.ProvisionerType
Tags json.RawMessage
Database database.Store
Pubsub database.Pubsub
Telemetry telemetry.Reporter
@ -71,6 +72,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
Valid: true,
},
Types: server.Provisioners,
Tags: server.Tags,
})
if errors.Is(err, sql.ErrNoRows) {
// The provisioner daemon assumes no jobs are available if

View File

@ -0,0 +1,33 @@
package provisionerdserver
import "github.com/google/uuid"
const (
TagScope = "scope"
TagOwner = "owner"
ScopeUser = "user"
ScopeOrganization = "organization"
)
// MutateTags adjusts the "owner" tag dependent on the "scope".
// If the scope is "user", the "owner" is changed to the user ID.
// This is for user-scoped provisioner daemons, where users should
// own their own operations.
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
if tags == nil {
tags = map[string]string{}
}
_, ok := tags[TagScope]
if !ok {
tags[TagScope] = ScopeOrganization
}
switch tags[TagScope] {
case ScopeUser:
tags[TagOwner] = userID.String()
case ScopeOrganization:
default:
tags[TagScope] = ScopeOrganization
}
return tags
}

View File

@ -131,10 +131,10 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
return
}
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@ -312,6 +312,7 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
CreatedAt: provisionerJob.CreatedAt,
Error: provisionerJob.Error.String,
FileID: provisionerJob.FileID,
Tags: provisionerJob.Tags,
}
// Applying values optional to the struct.
if provisionerJob.StartedAt.Valid {

View File

@ -291,6 +291,8 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques
FileID: job.FileID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: input,
// Copy tags from the previous run.
Tags: job.Tags,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@ -764,6 +766,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return
}
// Ensures the "owner" is properly applied.
tags := provisionerdserver.MutateTags(apiKey.UserID, req.ProvisionerTags)
file, err := api.Database.GetFileByID(ctx, req.FileID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@ -862,6 +867,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte{'{', '}'},
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)

View File

@ -13,6 +13,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
@ -122,6 +123,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, "bananas", version.Name)
require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope])
require.Len(t, auditor.AuditLogs, 1)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[0].Action)

View File

@ -181,10 +181,10 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
workspace := httpmw.WorkspaceParam(r)
@ -442,10 +442,10 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request
func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
@ -614,10 +614,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
}
}
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
conn, err := websocket.Accept(rw, r, nil)
@ -759,10 +759,10 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordin
func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)

View File

@ -428,6 +428,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
return
}
tags := provisionerdserver.MutateTags(workspace.OwnerID, templateVersionJob.Tags)
// Store prior build number to compute new build number
var priorBuildNum int32
priorHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
@ -513,6 +515,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
StorageMethod: templateVersionJob.StorageMethod,
FileID: templateVersionJob.FileID,
Input: input,
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)

View File

@ -373,6 +373,8 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
return
}
tags := provisionerdserver.MutateTags(user.ID, templateVersionJob.Tags)
var (
provisionerJob database.ProvisionerJob
workspaceBuild database.WorkspaceBuild
@ -435,6 +437,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
StorageMethod: templateVersionJob.StorageMethod,
FileID: templateVersionJob.FileID,
Input: input,
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)