mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
chore: Move httpapi, httpmw, & database into coderd
(#568)
* chore: Move httpmw to /coderd directory httpmw is specific to coderd and should be scoped under coderd * chore: Move httpapi to /coderd directory httpapi is specific to coderd and should be scoped under coderd * chore: Move database to /coderd directory database is specific to coderd and should be scoped under coderd * chore: Update codecov & gitattributes for generated files * chore: Update Makefile
This commit is contained in:
@ -10,9 +10,9 @@ import (
|
||||
"google.golang.org/api/idtoken"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/site"
|
||||
)
|
||||
|
||||
|
@ -31,11 +31,11 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/database/databasefake"
|
||||
"github.com/coder/coder/database/postgres"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionerd"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
|
1242
coderd/database/databasefake/databasefake.go
Normal file
1242
coderd/database/databasefake/databasefake.go
Normal file
File diff suppressed because it is too large
Load Diff
75
coderd/database/db.go
Normal file
75
coderd/database/db.go
Normal file
@ -0,0 +1,75 @@
|
||||
// Package database connects to external services for stateful storage.
|
||||
//
|
||||
// Query functions are generated using sqlc.
|
||||
//
|
||||
// To modify the database schema:
|
||||
// 1. Add a new migration using "create_migration.sh" in database/migrations/
|
||||
// 2. Run "make coderd/database/generate" in the root to generate models.
|
||||
// 3. Add/Edit queries in "query.sql" and run "make coderd/database/generate" to create Go code.
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {
	// querier is the sqlc-generated query interface (defined elsewhere in
	// this package).
	querier

	// InTx runs the given function with every Store operation scoped to a
	// single database transaction, committing on nil and rolling back on
	// error.
	InTx(func(Store) error) error
}

// DBTX represents a database connection or transaction.
// Both *sql.DB and *sql.Tx satisfy this interface, which is what lets the
// same generated query code run inside or outside a transaction.
type DBTX interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
	PrepareContext(context.Context, string) (*sql.Stmt, error)
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
|
||||
|
||||
// New creates a new database store using a SQL database connection.
|
||||
func New(sdb *sql.DB) Store {
|
||||
return &sqlQuerier{
|
||||
db: sdb,
|
||||
sdb: sdb,
|
||||
}
|
||||
}
|
||||
|
||||
type sqlQuerier struct {
|
||||
sdb *sql.DB
|
||||
db DBTX
|
||||
}
|
||||
|
||||
// InTx performs database operations inside a transaction.
|
||||
func (q *sqlQuerier) InTx(function func(Store) error) error {
|
||||
if q.sdb == nil {
|
||||
return nil
|
||||
}
|
||||
transaction, err := q.sdb.Begin()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
rerr := transaction.Rollback()
|
||||
if rerr == nil || errors.Is(rerr, sql.ErrTxDone) {
|
||||
// no need to do anything, tx committed successfully
|
||||
return
|
||||
}
|
||||
// couldn't roll back for some reason, extend returned error
|
||||
err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err)
|
||||
}()
|
||||
err = function(&sqlQuerier{db: transaction})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute transaction: %w", err)
|
||||
}
|
||||
err = transaction.Commit()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
378
coderd/database/dump.sql
generated
Normal file
378
coderd/database/dump.sql
generated
Normal file
@ -0,0 +1,378 @@
|
||||
-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.
|
||||
|
||||
CREATE TYPE log_level AS ENUM (
|
||||
'trace',
|
||||
'debug',
|
||||
'info',
|
||||
'warn',
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE log_source AS ENUM (
|
||||
'provisioner_daemon',
|
||||
'provisioner'
|
||||
);
|
||||
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_destination_scheme AS ENUM (
|
||||
'none',
|
||||
'environment_variable',
|
||||
'provisioner_variable'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_scope AS ENUM (
|
||||
'organization',
|
||||
'project',
|
||||
'import_job',
|
||||
'user',
|
||||
'workspace'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_source_scheme AS ENUM (
|
||||
'none',
|
||||
'data'
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_type_system AS ENUM (
|
||||
'none',
|
||||
'hcl'
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_job_type AS ENUM (
|
||||
'project_version_import',
|
||||
'workspace_build'
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_storage_method AS ENUM (
|
||||
'file'
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_type AS ENUM (
|
||||
'echo',
|
||||
'terraform'
|
||||
);
|
||||
|
||||
CREATE TYPE userstatus AS ENUM (
|
||||
'active',
|
||||
'dormant',
|
||||
'decommissioned'
|
||||
);
|
||||
|
||||
CREATE TYPE workspace_transition AS ENUM (
|
||||
'start',
|
||||
'stop',
|
||||
'delete'
|
||||
);
|
||||
|
||||
CREATE TABLE api_keys (
|
||||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE files (
|
||||
hash character varying(64) NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
created_by text NOT NULL,
|
||||
mimetype character varying(64) NOT NULL,
|
||||
data bytea NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE licenses (
|
||||
id integer NOT NULL,
|
||||
license jsonb NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE organization_members (
|
||||
organization_id text NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
roles text[] DEFAULT '{organization-member}'::text[] NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE organizations (
|
||||
id text NOT NULL,
|
||||
name text NOT NULL,
|
||||
description text NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
"default" boolean DEFAULT false NOT NULL,
|
||||
auto_off_threshold bigint DEFAULT '28800000000000'::bigint NOT NULL,
|
||||
cpu_provisioning_rate real DEFAULT 4.0 NOT NULL,
|
||||
memory_provisioning_rate real DEFAULT 1.0 NOT NULL,
|
||||
workspace_auto_off boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
job_id uuid NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
description character varying(8192) DEFAULT ''::character varying NOT NULL,
|
||||
default_source_scheme parameter_source_scheme,
|
||||
default_source_value text NOT NULL,
|
||||
allow_override_source boolean NOT NULL,
|
||||
default_destination_scheme parameter_destination_scheme,
|
||||
allow_override_destination boolean NOT NULL,
|
||||
default_refresh text NOT NULL,
|
||||
redisplay_value boolean NOT NULL,
|
||||
validation_error character varying(256) NOT NULL,
|
||||
validation_condition character varying(512) NOT NULL,
|
||||
validation_type_system parameter_type_system NOT NULL,
|
||||
validation_value_type character varying(64) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE parameter_values (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
scope parameter_scope NOT NULL,
|
||||
scope_id text NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
source_scheme parameter_source_scheme NOT NULL,
|
||||
source_value text NOT NULL,
|
||||
destination_scheme parameter_destination_scheme NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE project_versions (
|
||||
id uuid NOT NULL,
|
||||
project_id uuid,
|
||||
organization_id text NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
description character varying(1048576) NOT NULL,
|
||||
job_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE projects (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
organization_id text NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
active_version_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE provisioner_daemons (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone,
|
||||
organization_id text,
|
||||
name character varying(64) NOT NULL,
|
||||
provisioners provisioner_type[] NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE provisioner_job_logs (
|
||||
id uuid NOT NULL,
|
||||
job_id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
source log_source NOT NULL,
|
||||
level log_level NOT NULL,
|
||||
output character varying(1024) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE provisioner_jobs (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
started_at timestamp with time zone,
|
||||
canceled_at timestamp with time zone,
|
||||
completed_at timestamp with time zone,
|
||||
error text,
|
||||
organization_id text NOT NULL,
|
||||
initiator_id text NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
storage_method provisioner_storage_method NOT NULL,
|
||||
storage_source text NOT NULL,
|
||||
type provisioner_job_type NOT NULL,
|
||||
input jsonb NOT NULL,
|
||||
worker_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE users (
|
||||
id text NOT NULL,
|
||||
email text NOT NULL,
|
||||
name text NOT NULL,
|
||||
revoked boolean NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
hashed_password bytea NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
temporary_password boolean DEFAULT false NOT NULL,
|
||||
avatar_hash text DEFAULT ''::text NOT NULL,
|
||||
ssh_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
username text DEFAULT ''::text NOT NULL,
|
||||
dotfiles_git_uri text DEFAULT ''::text NOT NULL,
|
||||
roles text[] DEFAULT '{site-member}'::text[] NOT NULL,
|
||||
status userstatus DEFAULT 'active'::public.userstatus NOT NULL,
|
||||
relatime timestamp with time zone DEFAULT now() NOT NULL,
|
||||
gpg_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
_decomissioned boolean DEFAULT false NOT NULL,
|
||||
shell text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
first_connected_at timestamp with time zone,
|
||||
last_connected_at timestamp with time zone,
|
||||
disconnected_at timestamp with time zone,
|
||||
resource_id uuid NOT NULL,
|
||||
auth_token uuid NOT NULL,
|
||||
auth_instance_id character varying(64),
|
||||
environment_variables jsonb,
|
||||
startup_script character varying(65534),
|
||||
instance_metadata jsonb,
|
||||
resource_metadata jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_builds (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
workspace_id uuid NOT NULL,
|
||||
project_version_id uuid NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
before_id uuid,
|
||||
after_id uuid,
|
||||
transition workspace_transition NOT NULL,
|
||||
initiator character varying(255) NOT NULL,
|
||||
provisioner_state bytea,
|
||||
job_id uuid NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_resources (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
job_id uuid NOT NULL,
|
||||
transition workspace_transition NOT NULL,
|
||||
address character varying(256) NOT NULL,
|
||||
type character varying(192) NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
agent_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE workspaces (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
owner_id text NOT NULL,
|
||||
project_id uuid NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
name character varying(64) NOT NULL
|
||||
);
|
||||
|
||||
ALTER TABLE ONLY files
|
||||
ADD CONSTRAINT files_hash_key UNIQUE (hash);
|
||||
|
||||
ALTER TABLE ONLY parameter_schemas
|
||||
ADD CONSTRAINT parameter_schemas_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY parameter_schemas
|
||||
ADD CONSTRAINT parameter_schemas_job_id_name_key UNIQUE (job_id, name);
|
||||
|
||||
ALTER TABLE ONLY parameter_values
|
||||
ADD CONSTRAINT parameter_values_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY parameter_values
|
||||
ADD CONSTRAINT parameter_values_scope_id_name_key UNIQUE (scope_id, name);
|
||||
|
||||
ALTER TABLE ONLY project_versions
|
||||
ADD CONSTRAINT project_versions_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY project_versions
|
||||
ADD CONSTRAINT project_versions_project_id_name_key UNIQUE (project_id, name);
|
||||
|
||||
ALTER TABLE ONLY projects
|
||||
ADD CONSTRAINT projects_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY projects
|
||||
ADD CONSTRAINT projects_organization_id_name_key UNIQUE (organization_id, name);
|
||||
|
||||
ALTER TABLE ONLY provisioner_daemons
|
||||
ADD CONSTRAINT provisioner_daemons_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY provisioner_daemons
|
||||
ADD CONSTRAINT provisioner_daemons_name_key UNIQUE (name);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs
|
||||
ADD CONSTRAINT provisioner_job_logs_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY provisioner_jobs
|
||||
ADD CONSTRAINT provisioner_jobs_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY workspace_agents
|
||||
ADD CONSTRAINT workspace_agents_auth_token_key UNIQUE (auth_token);
|
||||
|
||||
ALTER TABLE ONLY workspace_agents
|
||||
ADD CONSTRAINT workspace_agents_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_job_id_key UNIQUE (job_id);
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_workspace_id_name_key UNIQUE (workspace_id, name);
|
||||
|
||||
ALTER TABLE ONLY workspace_resources
|
||||
ADD CONSTRAINT workspace_resources_id_key UNIQUE (id);
|
||||
|
||||
ALTER TABLE ONLY workspaces
|
||||
ADD CONSTRAINT workspaces_id_key UNIQUE (id);
|
||||
|
||||
CREATE UNIQUE INDEX projects_organization_id_name_idx ON projects USING btree (organization_id, name) WHERE (deleted = false);
|
||||
|
||||
CREATE UNIQUE INDEX workspaces_owner_id_name_idx ON workspaces USING btree (owner_id, name) WHERE (deleted = false);
|
||||
|
||||
ALTER TABLE ONLY parameter_schemas
|
||||
ADD CONSTRAINT parameter_schemas_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY project_versions
|
||||
ADD CONSTRAINT project_versions_project_id_fkey FOREIGN KEY (project_id) REFERENCES projects(id);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs
|
||||
ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_agents
|
||||
ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_project_version_id_fkey FOREIGN KEY (project_version_id) REFERENCES project_versions(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_builds
|
||||
ADD CONSTRAINT workspace_builds_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspace_resources
|
||||
ADD CONSTRAINT workspace_resources_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY workspaces
|
||||
ADD CONSTRAINT workspaces_project_id_fkey FOREIGN KEY (project_id) REFERENCES projects(id);
|
||||
|
88
coderd/database/dump/main.go
Normal file
88
coderd/database/dump/main.go
Normal file
@ -0,0 +1,88 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
)
|
||||
|
||||
// main regenerates dump.sql: it opens a throwaway Postgres instance,
// applies every migration, captures the schema with pg_dump, normalizes
// the output through a chain of sed passes, and writes the result next to
// this package as coderd/database/dump.sql.
func main() {
	connection, closeFn, err := postgres.Open()
	if err != nil {
		panic(err)
	}
	defer closeFn()
	db, err := sql.Open("postgres", connection)
	if err != nil {
		panic(err)
	}
	err = database.MigrateUp(db)
	if err != nil {
		panic(err)
	}
	cmd := exec.Command(
		"pg_dump",
		"--schema-only",
		connection,
		"--no-privileges",
		"--no-owner",
		"--no-comments",

		// We never want to manually generate
		// queries executing against this table.
		"--exclude-table=schema_migrations",
	)
	// Pin timezone and client encoding so the dump is byte-stable across
	// machines.
	cmd.Env = []string{
		"PGTZ=UTC",
		"PGCLIENTENCODING=UTF8",
	}
	var output bytes.Buffer
	cmd.Stdout = &output
	cmd.Stderr = os.Stderr
	err = cmd.Run()
	if err != nil {
		panic(err)
	}

	// Each sed expression is a separate pass over the previous pass's
	// output; `output` is reset and refilled on every iteration.
	for _, sed := range []string{
		// Remove all comments.
		"/^--/d",
		// Public is implicit in the schema.
		"s/ public\\./ /",
		// Remove database settings.
		"s/SET.*;//g",
		// Remove select statements. These aren't useful
		// to a reader of the dump.
		"s/SELECT.*;//g",
		// Removes multiple newlines.
		"/^$/N;/^\\n$/D",
	} {
		cmd := exec.Command("sed", "-e", sed)
		cmd.Stdin = bytes.NewReader(output.Bytes())
		output = bytes.Buffer{}
		cmd.Stdout = &output
		cmd.Stderr = os.Stderr
		err = cmd.Run()
		if err != nil {
			panic(err)
		}
	}

	// Prepend the generated-file marker, then write dump.sql two levels up
	// from this file (this package lives in a dump/ subdirectory).
	dump := fmt.Sprintf("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.\n%s", output.Bytes())
	_, mainPath, _, ok := runtime.Caller(0)
	if !ok {
		panic("couldn't get caller path")
	}
	err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0600)
	if err != nil {
		panic(err)
	}
}
|
74
coderd/database/migrate.go
Normal file
74
coderd/database/migrate.go
Normal file
@ -0,0 +1,74 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrations embed.FS
|
||||
|
||||
func migrateSetup(db *sql.DB) (*migrate.Migrate, error) {
|
||||
sourceDriver, err := iofs.New(migrations, "migrations")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create iofs: %w", err)
|
||||
}
|
||||
|
||||
dbDriver, err := postgres.WithInstance(db, &postgres.Config{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("wrap postgres connection: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new migrate instance: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// MigrateUp runs SQL migrations to ensure the database schema is up-to-date.
|
||||
func MigrateUp(db *sql.DB) error {
|
||||
m, err := migrateSetup(db)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate setup: %w", err)
|
||||
}
|
||||
|
||||
err = m.Up()
|
||||
if err != nil {
|
||||
if errors.Is(err, migrate.ErrNoChange) {
|
||||
// It's OK if no changes happened!
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("up: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateDown runs all down SQL migrations.
|
||||
func MigrateDown(db *sql.DB) error {
|
||||
m, err := migrateSetup(db)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate setup: %w", err)
|
||||
}
|
||||
|
||||
err = m.Down()
|
||||
if err != nil {
|
||||
if errors.Is(err, migrate.ErrNoChange) {
|
||||
// It's OK if no changes happened!
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("down: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
77
coderd/database/migrate_test.go
Normal file
77
coderd/database/migrate_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
//go:build linux
|
||||
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
)
|
||||
|
||||
// TestMain fails the package's tests if any of them leak goroutines.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
|
||||
|
||||
// TestMigrate exercises the migration entry points against a real
// Postgres instance: a single up-migration, an idempotent repeated
// up-migration, and a full up/down/up round trip.
func TestMigrate(t *testing.T) {
	t.Parallel()

	// Migrations need a live database; skip in -short runs.
	if testing.Short() {
		t.Skip()
		return
	}

	t.Run("Once", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		err := database.MigrateUp(db)
		require.NoError(t, err)
	})

	t.Run("Twice", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		// Running migrations against an already-migrated database must be
		// a no-op, not an error.
		err := database.MigrateUp(db)
		require.NoError(t, err)

		err = database.MigrateUp(db)
		require.NoError(t, err)
	})

	t.Run("UpDownUp", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		err := database.MigrateUp(db)
		require.NoError(t, err)

		err = database.MigrateDown(db)
		require.NoError(t, err)

		err = database.MigrateUp(db)
		require.NoError(t, err)
	})
}
|
||||
|
||||
// testSQLDB opens a fresh Postgres database for a single test and
// registers cleanup for both the database process and the connection.
func testSQLDB(t testing.TB) *sql.DB {
	t.Helper()

	connection, closeFn, err := postgres.Open()
	require.NoError(t, err)
	t.Cleanup(closeFn)

	db, err := sql.Open("postgres", connection)
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	return db
}
|
0
coderd/database/migrations/000001_base.down.sql
Normal file
0
coderd/database/migrations/000001_base.down.sql
Normal file
92
coderd/database/migrations/000001_base.up.sql
Normal file
92
coderd/database/migrations/000001_base.up.sql
Normal file
@ -0,0 +1,92 @@
|
||||
-- This migration creates tables and types for v1 if they do not exist.
|
||||
-- This allows v2 to operate independently of v1, but share data if it exists.
|
||||
--
|
||||
-- All tables and types are stolen from:
|
||||
-- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql
|
||||
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE login_type AS ENUM (
|
||||
'built-in',
|
||||
'saml',
|
||||
'oidc'
|
||||
);
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE userstatus AS ENUM (
|
||||
'active',
|
||||
'dormant',
|
||||
'decommissioned'
|
||||
);
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id text NOT NULL,
|
||||
email text NOT NULL,
|
||||
name text NOT NULL,
|
||||
revoked boolean NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
hashed_password bytea NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
temporary_password boolean DEFAULT false NOT NULL,
|
||||
avatar_hash text DEFAULT '' :: text NOT NULL,
|
||||
ssh_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
username text DEFAULT '' :: text NOT NULL,
|
||||
dotfiles_git_uri text DEFAULT '' :: text NOT NULL,
|
||||
roles text [] DEFAULT '{site-member}' :: text [] NOT NULL,
|
||||
status userstatus DEFAULT 'active' :: public.userstatus NOT NULL,
|
||||
relatime timestamp with time zone DEFAULT now() NOT NULL,
|
||||
gpg_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
_decomissioned boolean DEFAULT false NOT NULL,
|
||||
shell text DEFAULT '' :: text NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organizations (
|
||||
id text NOT NULL,
|
||||
name text NOT NULL,
|
||||
description text NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
"default" boolean DEFAULT false NOT NULL,
|
||||
auto_off_threshold bigint DEFAULT '28800000000000' :: bigint NOT NULL,
|
||||
cpu_provisioning_rate real DEFAULT 4.0 NOT NULL,
|
||||
memory_provisioning_rate real DEFAULT 1.0 NOT NULL,
|
||||
workspace_auto_off boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organization_members (
|
||||
organization_id text NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
roles text [] DEFAULT '{organization-member}' :: text [] NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
id text NOT NULL,
|
||||
hashed_secret bytea NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
application boolean NOT NULL,
|
||||
name text NOT NULL,
|
||||
last_used timestamp with time zone NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
login_type login_type NOT NULL,
|
||||
oidc_access_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_refresh_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_id_token text DEFAULT ''::text NOT NULL,
|
||||
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
|
||||
devurl_token boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS licenses (
|
||||
id integer NOT NULL,
|
||||
license jsonb NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
6
coderd/database/migrations/000002_projects.down.sql
Normal file
6
coderd/database/migrations/000002_projects.down.sql
Normal file
@ -0,0 +1,6 @@
|
||||
DROP TABLE project_versions;
|
||||
|
||||
DROP TABLE projects;
|
||||
DROP TYPE provisioner_type;
|
||||
|
||||
DROP TABLE files;
|
55
coderd/database/migrations/000002_projects.up.sql
Normal file
55
coderd/database/migrations/000002_projects.up.sql
Normal file
@ -0,0 +1,55 @@
|
||||
-- Store arbitrary data like project source code or avatars.
|
||||
CREATE TABLE files (
|
||||
hash varchar(64) NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
created_by text NOT NULL,
|
||||
mimetype varchar(64) NOT NULL,
|
||||
data bytea NOT NULL
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_type AS ENUM ('echo', 'terraform');
|
||||
|
||||
-- Project defines infrastructure that your software project
|
||||
-- requires for development.
|
||||
CREATE TABLE projects (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
-- Projects must be scoped to an organization.
|
||||
organization_id text NOT NULL,
|
||||
deleted boolean NOT NULL DEFAULT FALSE,
|
||||
name varchar(64) NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
-- Target's a Project Version to use for Workspaces.
|
||||
-- If a Workspace doesn't match this version, it will be prompted to rebuild.
|
||||
active_version_id uuid NOT NULL,
|
||||
-- Disallow projects to have the same name under
|
||||
-- the same organization.
|
||||
UNIQUE(organization_id, name)
|
||||
);
|
||||
|
||||
-- Enforces no active projects have the same name.
|
||||
CREATE UNIQUE INDEX ON projects (organization_id, name) WHERE deleted = FALSE;
|
||||
|
||||
-- Project Versions store historical project data. When a Project Version is imported,
|
||||
-- an "import" job is queued to parse parameters. A Project Version
|
||||
-- can only be used if the import job succeeds.
|
||||
CREATE TABLE project_versions (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
-- This should be indexed.
|
||||
project_id uuid REFERENCES projects (id),
|
||||
organization_id text NOT NULL,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
-- Name is generated for ease of differentiation.
|
||||
-- eg. TheCozyRabbit16
|
||||
name varchar(64) NOT NULL,
|
||||
-- Extracted from a README.md on import.
|
||||
-- Maximum of 1MB.
|
||||
description varchar(1048576) NOT NULL,
|
||||
-- The job ID for building the project version.
|
||||
job_id uuid NOT NULL,
|
||||
-- Disallow projects to have the same build name
|
||||
-- multiple times.
|
||||
UNIQUE(project_id, name)
|
||||
);
|
2
coderd/database/migrations/000003_workspaces.down.sql
Normal file
2
coderd/database/migrations/000003_workspaces.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
DROP TYPE workspace_transition;
|
||||
DROP TABLE workspaces
|
19
coderd/database/migrations/000003_workspaces.up.sql
Normal file
19
coderd/database/migrations/000003_workspaces.up.sql
Normal file
@ -0,0 +1,19 @@
|
||||
CREATE TABLE workspaces (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
owner_id text NOT NULL,
|
||||
project_id uuid NOT NULL REFERENCES projects (id),
|
||||
deleted boolean NOT NULL DEFAULT FALSE,
|
||||
name varchar(64) NOT NULL
|
||||
);
|
||||
|
||||
-- Enforces no active workspaces have the same name.
|
||||
CREATE UNIQUE INDEX ON workspaces (owner_id, name) WHERE deleted = FALSE;
|
||||
|
||||
CREATE TYPE workspace_transition AS ENUM (
|
||||
'start',
|
||||
'stop',
|
||||
'delete'
|
||||
);
|
||||
|
21
coderd/database/migrations/000004_jobs.down.sql
Normal file
21
coderd/database/migrations/000004_jobs.down.sql
Normal file
@ -0,0 +1,21 @@
|
||||
DROP TABLE workspace_builds;
|
||||
|
||||
DROP TABLE parameter_values;
|
||||
DROP TABLE parameter_schemas;
|
||||
DROP TYPE parameter_destination_scheme;
|
||||
DROP TYPE parameter_source_scheme;
|
||||
DROP TYPE parameter_type_system;
|
||||
DROP TYPE parameter_scope;
|
||||
|
||||
DROP TABLE workspace_agents;
|
||||
DROP TABLE workspace_resources;
|
||||
|
||||
DROP TABLE provisioner_job_logs;
|
||||
DROP TYPE log_source;
|
||||
DROP TYPE log_level;
|
||||
|
||||
DROP TABLE provisioner_jobs;
|
||||
DROP TYPE provisioner_storage_method;
|
||||
DROP TYPE provisioner_job_type;
|
||||
|
||||
DROP TABLE provisioner_daemons;
|
167
coderd/database/migrations/000004_jobs.up.sql
Normal file
167
coderd/database/migrations/000004_jobs.up.sql
Normal file
@ -0,0 +1,167 @@
|
||||
CREATE TABLE IF NOT EXISTS provisioner_daemons (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz,
|
||||
organization_id text,
|
||||
-- Name is generated for ease of differentiation.
|
||||
-- eg. WowBananas16
|
||||
name varchar(64) NOT NULL UNIQUE,
|
||||
provisioners provisioner_type [ ] NOT NULL
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_job_type AS ENUM (
|
||||
'project_version_import',
|
||||
'workspace_build'
|
||||
);
|
||||
|
||||
CREATE TYPE provisioner_storage_method AS ENUM ('file');
|
||||
|
||||
CREATE TABLE IF NOT EXISTS provisioner_jobs (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
started_at timestamptz,
|
||||
canceled_at timestamptz,
|
||||
completed_at timestamptz,
|
||||
error text,
|
||||
organization_id text NOT NULL,
|
||||
initiator_id text NOT NULL,
|
||||
provisioner provisioner_type NOT NULL,
|
||||
storage_method provisioner_storage_method NOT NULL,
|
||||
storage_source text NOT NULL,
|
||||
type provisioner_job_type NOT NULL,
|
||||
input jsonb NOT NULL,
|
||||
worker_id uuid
|
||||
);
|
||||
|
||||
CREATE TYPE log_level AS ENUM (
|
||||
'trace',
|
||||
'debug',
|
||||
'info',
|
||||
'warn',
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE log_source AS ENUM (
|
||||
'provisioner_daemon',
|
||||
'provisioner'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS provisioner_job_logs (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
|
||||
created_at timestamptz NOT NULL,
|
||||
source log_source NOT NULL,
|
||||
level log_level NOT NULL,
|
||||
output varchar(1024) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_resources (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
|
||||
transition workspace_transition NOT NULL,
|
||||
address varchar(256) NOT NULL,
|
||||
type varchar(192) NOT NULL,
|
||||
name varchar(64) NOT NULL,
|
||||
agent_id uuid
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
first_connected_at timestamptz,
|
||||
last_connected_at timestamptz,
|
||||
disconnected_at timestamptz,
|
||||
resource_id uuid NOT NULL REFERENCES workspace_resources (id) ON DELETE CASCADE,
|
||||
auth_token uuid NOT NULL UNIQUE,
|
||||
auth_instance_id varchar(64),
|
||||
environment_variables jsonb,
|
||||
startup_script varchar(65534),
|
||||
instance_metadata jsonb,
|
||||
resource_metadata jsonb
|
||||
);
|
||||
|
||||
CREATE TYPE parameter_scope AS ENUM (
|
||||
'organization',
|
||||
'project',
|
||||
'import_job',
|
||||
'user',
|
||||
'workspace'
|
||||
);
|
||||
|
||||
-- Types of parameters the automator supports.
|
||||
CREATE TYPE parameter_type_system AS ENUM ('none', 'hcl');
|
||||
|
||||
-- Supported schemes for a parameter source.
|
||||
CREATE TYPE parameter_source_scheme AS ENUM('none', 'data');
|
||||
|
||||
-- Supported schemes for a parameter destination.
|
||||
CREATE TYPE parameter_destination_scheme AS ENUM('none', 'environment_variable', 'provisioner_variable');
|
||||
|
||||
-- Stores project version parameters parsed on import.
|
||||
-- No secrets are stored here.
|
||||
--
|
||||
-- All parameter validation occurs server-side to process
|
||||
-- complex validations.
|
||||
--
|
||||
-- Parameter types, description, and validation will produce
|
||||
-- a UI for users to enter values.
|
||||
-- Needs to be made consistent with the examples below.
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
|
||||
name varchar(64) NOT NULL,
|
||||
description varchar(8192) NOT NULL DEFAULT '',
|
||||
default_source_scheme parameter_source_scheme,
|
||||
default_source_value text NOT NULL,
|
||||
-- Allows the user to override the source.
|
||||
allow_override_source boolean NOT null,
|
||||
default_destination_scheme parameter_destination_scheme,
|
||||
-- Allows the user to override the destination.
|
||||
allow_override_destination boolean NOT null,
|
||||
default_refresh text NOT NULL,
|
||||
-- Whether the consumer can view the source and destinations.
|
||||
redisplay_value boolean NOT null,
|
||||
-- This error would appear in the UI if the condition is not met.
|
||||
validation_error varchar(256) NOT NULL,
|
||||
validation_condition varchar(512) NOT NULL,
|
||||
validation_type_system parameter_type_system NOT NULL,
|
||||
validation_value_type varchar(64) NOT NULL,
|
||||
UNIQUE(job_id, name)
|
||||
);
|
||||
|
||||
-- Parameters are provided to jobs for provisioning and to workspaces.
|
||||
CREATE TABLE parameter_values (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
scope parameter_scope NOT NULL,
|
||||
scope_id text NOT NULL,
|
||||
name varchar(64) NOT NULL,
|
||||
source_scheme parameter_source_scheme NOT NULL,
|
||||
source_value text NOT NULL,
|
||||
destination_scheme parameter_destination_scheme NOT NULL,
|
||||
-- Prevents duplicates for parameters in the same scope.
|
||||
UNIQUE(scope_id, name)
|
||||
);
|
||||
|
||||
CREATE TABLE workspace_builds (
|
||||
id uuid NOT NULL UNIQUE,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
workspace_id uuid NOT NULL REFERENCES workspaces (id) ON DELETE CASCADE,
|
||||
project_version_id uuid NOT NULL REFERENCES project_versions (id) ON DELETE CASCADE,
|
||||
name varchar(64) NOT NULL,
|
||||
before_id uuid,
|
||||
after_id uuid,
|
||||
transition workspace_transition NOT NULL,
|
||||
initiator varchar(255) NOT NULL,
|
||||
-- State stored by the provisioner
|
||||
provisioner_state bytea,
|
||||
-- Job ID of the action
|
||||
job_id uuid NOT NULL UNIQUE REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
|
||||
UNIQUE(workspace_id, name)
|
||||
);
|
14
coderd/database/migrations/create_migration.sh
Executable file
14
coderd/database/migrations/create_migration.sh
Executable file
@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "First argument is the migration name!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
migrate create -ext sql -dir . -seq "$1"
|
||||
|
||||
echo "Run \"make gen\" to generate models."
|
466
coderd/database/models.go
Normal file
466
coderd/database/models.go
Normal file
@ -0,0 +1,466 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
)
|
||||
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
LogLevelTrace LogLevel = "trace"
|
||||
LogLevelDebug LogLevel = "debug"
|
||||
LogLevelInfo LogLevel = "info"
|
||||
LogLevelWarn LogLevel = "warn"
|
||||
LogLevelError LogLevel = "error"
|
||||
)
|
||||
|
||||
func (e *LogLevel) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = LogLevel(s)
|
||||
case string:
|
||||
*e = LogLevel(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for LogLevel: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LogSource string
|
||||
|
||||
const (
|
||||
LogSourceProvisionerDaemon LogSource = "provisioner_daemon"
|
||||
LogSourceProvisioner LogSource = "provisioner"
|
||||
)
|
||||
|
||||
func (e *LogSource) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = LogSource(s)
|
||||
case string:
|
||||
*e = LogSource(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for LogSource: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type LoginType string
|
||||
|
||||
const (
|
||||
LoginTypeBuiltIn LoginType = "built-in"
|
||||
LoginTypeSaml LoginType = "saml"
|
||||
LoginTypeOIDC LoginType = "oidc"
|
||||
)
|
||||
|
||||
func (e *LoginType) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = LoginType(s)
|
||||
case string:
|
||||
*e = LoginType(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for LoginType: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ParameterDestinationScheme string
|
||||
|
||||
const (
|
||||
ParameterDestinationSchemeNone ParameterDestinationScheme = "none"
|
||||
ParameterDestinationSchemeEnvironmentVariable ParameterDestinationScheme = "environment_variable"
|
||||
ParameterDestinationSchemeProvisionerVariable ParameterDestinationScheme = "provisioner_variable"
|
||||
)
|
||||
|
||||
func (e *ParameterDestinationScheme) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ParameterDestinationScheme(s)
|
||||
case string:
|
||||
*e = ParameterDestinationScheme(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ParameterDestinationScheme: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ParameterScope string
|
||||
|
||||
const (
|
||||
ParameterScopeOrganization ParameterScope = "organization"
|
||||
ParameterScopeProject ParameterScope = "project"
|
||||
ParameterScopeImportJob ParameterScope = "import_job"
|
||||
ParameterScopeUser ParameterScope = "user"
|
||||
ParameterScopeWorkspace ParameterScope = "workspace"
|
||||
)
|
||||
|
||||
func (e *ParameterScope) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ParameterScope(s)
|
||||
case string:
|
||||
*e = ParameterScope(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ParameterScope: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ParameterSourceScheme string
|
||||
|
||||
const (
|
||||
ParameterSourceSchemeNone ParameterSourceScheme = "none"
|
||||
ParameterSourceSchemeData ParameterSourceScheme = "data"
|
||||
)
|
||||
|
||||
func (e *ParameterSourceScheme) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ParameterSourceScheme(s)
|
||||
case string:
|
||||
*e = ParameterSourceScheme(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ParameterSourceScheme: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ParameterTypeSystem string
|
||||
|
||||
const (
|
||||
ParameterTypeSystemNone ParameterTypeSystem = "none"
|
||||
ParameterTypeSystemHCL ParameterTypeSystem = "hcl"
|
||||
)
|
||||
|
||||
func (e *ParameterTypeSystem) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ParameterTypeSystem(s)
|
||||
case string:
|
||||
*e = ParameterTypeSystem(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ParameterTypeSystem: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProvisionerJobType string
|
||||
|
||||
const (
|
||||
ProvisionerJobTypeProjectVersionImport ProvisionerJobType = "project_version_import"
|
||||
ProvisionerJobTypeWorkspaceBuild ProvisionerJobType = "workspace_build"
|
||||
)
|
||||
|
||||
func (e *ProvisionerJobType) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ProvisionerJobType(s)
|
||||
case string:
|
||||
*e = ProvisionerJobType(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ProvisionerJobType: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProvisionerStorageMethod string
|
||||
|
||||
const (
|
||||
ProvisionerStorageMethodFile ProvisionerStorageMethod = "file"
|
||||
)
|
||||
|
||||
func (e *ProvisionerStorageMethod) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ProvisionerStorageMethod(s)
|
||||
case string:
|
||||
*e = ProvisionerStorageMethod(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ProvisionerStorageMethod: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProvisionerType string
|
||||
|
||||
const (
|
||||
ProvisionerTypeEcho ProvisionerType = "echo"
|
||||
ProvisionerTypeTerraform ProvisionerType = "terraform"
|
||||
)
|
||||
|
||||
func (e *ProvisionerType) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = ProvisionerType(s)
|
||||
case string:
|
||||
*e = ProvisionerType(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for ProvisionerType: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
UserstatusActive UserStatus = "active"
|
||||
UserstatusDormant UserStatus = "dormant"
|
||||
UserstatusDecommissioned UserStatus = "decommissioned"
|
||||
)
|
||||
|
||||
func (e *UserStatus) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = UserStatus(s)
|
||||
case string:
|
||||
*e = UserStatus(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for UserStatus: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type WorkspaceTransition string
|
||||
|
||||
const (
|
||||
WorkspaceTransitionStart WorkspaceTransition = "start"
|
||||
WorkspaceTransitionStop WorkspaceTransition = "stop"
|
||||
WorkspaceTransitionDelete WorkspaceTransition = "delete"
|
||||
)
|
||||
|
||||
func (e *WorkspaceTransition) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = WorkspaceTransition(s)
|
||||
case string:
|
||||
*e = WorkspaceTransition(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for WorkspaceTransition: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
|
||||
UserID string `db:"user_id" json:"user_id"`
|
||||
Application bool `db:"application" json:"application"`
|
||||
Name string `db:"name" json:"name"`
|
||||
LastUsed time.Time `db:"last_used" json:"last_used"`
|
||||
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
|
||||
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
|
||||
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
|
||||
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
|
||||
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Hash string `db:"hash" json:"hash"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
CreatedBy string `db:"created_by" json:"created_by"`
|
||||
Mimetype string `db:"mimetype" json:"mimetype"`
|
||||
Data []byte `db:"data" json:"data"`
|
||||
}
|
||||
|
||||
type License struct {
|
||||
ID int32 `db:"id" json:"id"`
|
||||
License json.RawMessage `db:"license" json:"license"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type Organization struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Default bool `db:"default" json:"default"`
|
||||
AutoOffThreshold int64 `db:"auto_off_threshold" json:"auto_off_threshold"`
|
||||
CpuProvisioningRate float32 `db:"cpu_provisioning_rate" json:"cpu_provisioning_rate"`
|
||||
MemoryProvisioningRate float32 `db:"memory_provisioning_rate" json:"memory_provisioning_rate"`
|
||||
WorkspaceAutoOff bool `db:"workspace_auto_off" json:"workspace_auto_off"`
|
||||
}
|
||||
|
||||
type OrganizationMember struct {
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
UserID string `db:"user_id" json:"user_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Roles []string `db:"roles" json:"roles"`
|
||||
}
|
||||
|
||||
type ParameterSchema struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
DefaultSourceScheme ParameterSourceScheme `db:"default_source_scheme" json:"default_source_scheme"`
|
||||
DefaultSourceValue string `db:"default_source_value" json:"default_source_value"`
|
||||
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
|
||||
DefaultDestinationScheme ParameterDestinationScheme `db:"default_destination_scheme" json:"default_destination_scheme"`
|
||||
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
|
||||
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
|
||||
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
|
||||
ValidationError string `db:"validation_error" json:"validation_error"`
|
||||
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
|
||||
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
|
||||
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
|
||||
}
|
||||
|
||||
type ParameterValue struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Scope ParameterScope `db:"scope" json:"scope"`
|
||||
ScopeID string `db:"scope_id" json:"scope_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
SourceScheme ParameterSourceScheme `db:"source_scheme" json:"source_scheme"`
|
||||
SourceValue string `db:"source_value" json:"source_value"`
|
||||
DestinationScheme ParameterDestinationScheme `db:"destination_scheme" json:"destination_scheme"`
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
|
||||
ActiveVersionID uuid.UUID `db:"active_version_id" json:"active_version_id"`
|
||||
}
|
||||
|
||||
type ProjectVersion struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ProjectID uuid.NullUUID `db:"project_id" json:"project_id"`
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Description string `db:"description" json:"description"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
}
|
||||
|
||||
type ProvisionerDaemon struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
|
||||
OrganizationID sql.NullString `db:"organization_id" json:"organization_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
|
||||
}
|
||||
|
||||
type ProvisionerJob struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
CanceledAt sql.NullTime `db:"canceled_at" json:"canceled_at"`
|
||||
CompletedAt sql.NullTime `db:"completed_at" json:"completed_at"`
|
||||
Error sql.NullString `db:"error" json:"error"`
|
||||
OrganizationID string `db:"organization_id" json:"organization_id"`
|
||||
InitiatorID string `db:"initiator_id" json:"initiator_id"`
|
||||
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
|
||||
StorageMethod ProvisionerStorageMethod `db:"storage_method" json:"storage_method"`
|
||||
StorageSource string `db:"storage_source" json:"storage_source"`
|
||||
Type ProvisionerJobType `db:"type" json:"type"`
|
||||
Input json.RawMessage `db:"input" json:"input"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
}
|
||||
|
||||
type ProvisionerJobLog struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
Source LogSource `db:"source" json:"source"`
|
||||
Level LogLevel `db:"level" json:"level"`
|
||||
Output string `db:"output" json:"output"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Email string `db:"email" json:"email"`
|
||||
Name string `db:"name" json:"name"`
|
||||
Revoked bool `db:"revoked" json:"revoked"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
TemporaryPassword bool `db:"temporary_password" json:"temporary_password"`
|
||||
AvatarHash string `db:"avatar_hash" json:"avatar_hash"`
|
||||
SshKeyRegeneratedAt time.Time `db:"ssh_key_regenerated_at" json:"ssh_key_regenerated_at"`
|
||||
Username string `db:"username" json:"username"`
|
||||
DotfilesGitUri string `db:"dotfiles_git_uri" json:"dotfiles_git_uri"`
|
||||
Roles []string `db:"roles" json:"roles"`
|
||||
Status UserStatus `db:"status" json:"status"`
|
||||
Relatime time.Time `db:"relatime" json:"relatime"`
|
||||
GpgKeyRegeneratedAt time.Time `db:"gpg_key_regenerated_at" json:"gpg_key_regenerated_at"`
|
||||
Decomissioned bool `db:"_decomissioned" json:"_decomissioned"`
|
||||
Shell string `db:"shell" json:"shell"`
|
||||
}
|
||||
|
||||
type Workspace struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
OwnerID string `db:"owner_id" json:"owner_id"`
|
||||
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
type WorkspaceAgent struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
|
||||
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
|
||||
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
|
||||
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
|
||||
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
|
||||
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
|
||||
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
|
||||
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
|
||||
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
|
||||
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
|
||||
}
|
||||
|
||||
type WorkspaceBuild struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ProjectVersionID uuid.UUID `db:"project_version_id" json:"project_version_id"`
|
||||
Name string `db:"name" json:"name"`
|
||||
BeforeID uuid.NullUUID `db:"before_id" json:"before_id"`
|
||||
AfterID uuid.NullUUID `db:"after_id" json:"after_id"`
|
||||
Transition WorkspaceTransition `db:"transition" json:"transition"`
|
||||
Initiator string `db:"initiator" json:"initiator"`
|
||||
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
}
|
||||
|
||||
type WorkspaceResource struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
JobID uuid.UUID `db:"job_id" json:"job_id"`
|
||||
Transition WorkspaceTransition `db:"transition" json:"transition"`
|
||||
Address string `db:"address" json:"address"`
|
||||
Type string `db:"type" json:"type"`
|
||||
Name string `db:"name" json:"name"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
}
|
142
coderd/database/postgres/postgres.go
Normal file
142
coderd/database/postgres/postgres.go
Normal file
@ -0,0 +1,142 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// Required to prevent port collision during container creation.
|
||||
// Super unlikely, but it happened. See: https://github.com/coder/coder/runs/5375197003
|
||||
var openPortMutex sync.Mutex
|
||||
|
||||
// Open creates a new PostgreSQL server using a Docker container.
|
||||
func Open() (string, func(), error) {
|
||||
if os.Getenv("DB") == "ci" {
|
||||
// In CI, creating a Docker container for each test is slow.
|
||||
// This expects a PostgreSQL instance with the hardcoded credentials
|
||||
// available.
|
||||
dbURL := "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable"
|
||||
db, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("connect to ci postgres: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
dbName, err := cryptorand.StringCharset(cryptorand.Lower, 10)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("generate db name: %w", err)
|
||||
}
|
||||
dbName = "ci" + dbName
|
||||
_, err = db.Exec("CREATE DATABASE " + dbName)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("create db: %w", err)
|
||||
}
|
||||
return "postgres://postgres:postgres@127.0.0.1:5432/" + dbName + "?sslmode=disable", func() {}, nil
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("create pool: %w", err)
|
||||
}
|
||||
tempDir, err := ioutil.TempDir(os.TempDir(), "postgres")
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("create tempdir: %w", err)
|
||||
}
|
||||
openPortMutex.Lock()
|
||||
// Pick an explicit port on the host to connect to 5432.
|
||||
// This is necessary so we can configure the port to only use ipv4.
|
||||
port, err := getFreePort()
|
||||
if err != nil {
|
||||
openPortMutex.Unlock()
|
||||
return "", nil, xerrors.Errorf("Unable to get free port: %w", err)
|
||||
}
|
||||
|
||||
resource, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "postgres",
|
||||
Tag: "11",
|
||||
Env: []string{
|
||||
"POSTGRES_PASSWORD=postgres",
|
||||
"POSTGRES_USER=postgres",
|
||||
"POSTGRES_DB=postgres",
|
||||
// The location for temporary database files!
|
||||
"PGDATA=/tmp",
|
||||
"listen_addresses = '*'",
|
||||
},
|
||||
PortBindings: map[docker.Port][]docker.PortBinding{
|
||||
"5432/tcp": {{
|
||||
// Manually specifying a host IP tells Docker just to use an IPV4 address.
|
||||
// If we don't do this, we hit a fun bug:
|
||||
// https://github.com/moby/moby/issues/42442
|
||||
// where the ipv4 and ipv6 ports might be _different_ and collide with other running docker containers.
|
||||
HostIP: "0.0.0.0",
|
||||
HostPort: fmt.Sprintf("%d", port)}},
|
||||
},
|
||||
Mounts: []string{
|
||||
// The postgres image has a VOLUME parameter in it's image.
|
||||
// If we don't mount at this point, Docker will allocate a
|
||||
// volume for this directory.
|
||||
//
|
||||
// This isn't used anyways, since we override PGDATA.
|
||||
fmt.Sprintf("%s:/var/lib/postgresql/data", tempDir),
|
||||
},
|
||||
}, func(config *docker.HostConfig) {
|
||||
// set AutoRemove to true so that stopped container goes away by itself
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
openPortMutex.Unlock()
|
||||
return "", nil, xerrors.Errorf("could not start resource: %w", err)
|
||||
}
|
||||
openPortMutex.Unlock()
|
||||
|
||||
hostAndPort := resource.GetHostPort("5432/tcp")
|
||||
dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort)
|
||||
|
||||
// Docker should hard-kill the container after 120 seconds.
|
||||
err = resource.Expire(120)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("could not expire resource: %w", err)
|
||||
}
|
||||
|
||||
pool.MaxWait = 120 * time.Second
|
||||
err = pool.Retry(func() error {
|
||||
db, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.Ping()
|
||||
_ = db.Close()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return dbURL, func() {
|
||||
_ = pool.Purge(resource)
|
||||
_ = os.RemoveAll(tempDir)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFreePort asks the kernel for a free open port that is ready to use.
|
||||
func getFreePort() (port int, err error) {
|
||||
// Binding to port 0 tells the OS to grab a port for us:
|
||||
// https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer listener.Close()
|
||||
return listener.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
38
coderd/database/postgres/postgres_test.go
Normal file
38
coderd/database/postgres/postgres_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
//go:build linux
|
||||
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestPostgres(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
return
|
||||
}
|
||||
|
||||
connect, close, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer close()
|
||||
db, err := sql.Open("postgres", connect)
|
||||
require.NoError(t, err)
|
||||
err = db.Ping()
|
||||
require.NoError(t, err)
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
155
coderd/database/pubsub.go
Normal file
155
coderd/database/pubsub.go
Normal file
@ -0,0 +1,155 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Listener represents a pubsub handler.
|
||||
type Listener func(ctx context.Context, message []byte)
|
||||
|
||||
// Pubsub is a generic interface for broadcasting and receiving messages.
|
||||
// Implementors should assume high-availability with the backing implementation.
|
||||
type Pubsub interface {
|
||||
Subscribe(event string, listener Listener) (cancel func(), err error)
|
||||
Publish(event string, message []byte) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Pubsub implementation using PostgreSQL.
|
||||
type pgPubsub struct {
|
||||
pgListener *pq.Listener
|
||||
db *sql.DB
|
||||
mut sync.Mutex
|
||||
listeners map[string]map[string]Listener
|
||||
}
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
|
||||
err = p.pgListener.Listen(event)
|
||||
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
|
||||
// It's ok if it's already open!
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
var listeners map[string]Listener
|
||||
var ok bool
|
||||
if listeners, ok = p.listeners[event]; !ok {
|
||||
listeners = map[string]Listener{}
|
||||
p.listeners[event] = listeners
|
||||
}
|
||||
var id string
|
||||
for {
|
||||
id = uuid.New().String()
|
||||
if _, ok = listeners[id]; !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
listeners[id] = listener
|
||||
return func() {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
listeners := p.listeners[event]
|
||||
delete(listeners, id)
|
||||
|
||||
if len(listeners) == 0 {
|
||||
_ = p.pgListener.Unlisten(event)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *pgPubsub) Publish(event string, message []byte) error {
|
||||
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("exec: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the pubsub instance.
|
||||
func (p *pgPubsub) Close() error {
|
||||
return p.pgListener.Close()
|
||||
}
|
||||
|
||||
// listen begins receiving messages on the pq listener.
|
||||
func (p *pgPubsub) listen(ctx context.Context) {
|
||||
var (
|
||||
notif *pq.Notification
|
||||
ok bool
|
||||
)
|
||||
defer p.pgListener.Close()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case notif, ok = <-p.pgListener.Notify:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
// A nil notification can be dispatched on reconnect.
|
||||
if notif == nil {
|
||||
continue
|
||||
}
|
||||
p.listenReceive(ctx, notif)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
|
||||
p.mut.Lock()
|
||||
defer p.mut.Unlock()
|
||||
listeners, ok := p.listeners[notif.Channel]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
extra := []byte(notif.Extra)
|
||||
for _, listener := range listeners {
|
||||
go listener(ctx, extra)
|
||||
}
|
||||
}
|
||||
|
||||
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
||||
// Creates a new listener using pq.
|
||||
errCh := make(chan error)
|
||||
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
|
||||
// This callback gets events whenever the connection state changes.
|
||||
// Don't send if the errChannel has already been closed.
|
||||
select {
|
||||
case <-errCh:
|
||||
return
|
||||
default:
|
||||
errCh <- err
|
||||
close(errCh)
|
||||
}
|
||||
})
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create pq listener: %w", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
pgPubsub := &pgPubsub{
|
||||
db: database,
|
||||
pgListener: listener,
|
||||
listeners: make(map[string]map[string]Listener),
|
||||
}
|
||||
go pgPubsub.listen(ctx)
|
||||
|
||||
return pgPubsub, nil
|
||||
}
|
63
coderd/database/pubsub_memory.go
Normal file
63
coderd/database/pubsub_memory.go
Normal file
@ -0,0 +1,63 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// memoryPubsub is an in-memory Pubsub implementation.
|
||||
type memoryPubsub struct {
|
||||
mut sync.RWMutex
|
||||
listeners map[string]map[uuid.UUID]Listener
|
||||
}
|
||||
|
||||
func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
|
||||
var listeners map[uuid.UUID]Listener
|
||||
var ok bool
|
||||
if listeners, ok = m.listeners[event]; !ok {
|
||||
listeners = map[uuid.UUID]Listener{}
|
||||
m.listeners[event] = listeners
|
||||
}
|
||||
var id uuid.UUID
|
||||
for {
|
||||
id = uuid.New()
|
||||
if _, ok = listeners[id]; !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
listeners[id] = listener
|
||||
return func() {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
listeners := m.listeners[event]
|
||||
delete(listeners, id)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *memoryPubsub) Publish(event string, message []byte) error {
|
||||
m.mut.RLock()
|
||||
defer m.mut.RUnlock()
|
||||
listeners, ok := m.listeners[event]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
for _, listener := range listeners {
|
||||
listener(context.Background(), message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*memoryPubsub) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPubsubInMemory() Pubsub {
|
||||
return &memoryPubsub{
|
||||
listeners: make(map[string]map[uuid.UUID]Listener),
|
||||
}
|
||||
}
|
35
coderd/database/pubsub_memory_test.go
Normal file
35
coderd/database/pubsub_memory_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
)
|
||||
|
||||
func TestPubsubMemory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Memory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
messageChannel := make(chan []byte)
|
||||
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
go func() {
|
||||
err = pubsub.Publish(event, []byte(data))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
}
|
70
coderd/database/pubsub_test.go
Normal file
70
coderd/database/pubsub_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
//go:build linux
|
||||
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
)
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("Postgres", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
connectionURL, close, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer close()
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
messageChannel := make(chan []byte)
|
||||
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancelFunc()
|
||||
go func() {
|
||||
err = pubsub.Publish(event, []byte(data))
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
message := <-messageChannel
|
||||
assert.Equal(t, string(message), data)
|
||||
})
|
||||
|
||||
t.Run("PostgresCloseCancel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
connectionURL, close, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
defer close()
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
cancelFunc()
|
||||
})
|
||||
}
|
84
coderd/database/querier.go
Normal file
84
coderd/database/querier.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type querier interface {
|
||||
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
|
||||
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
GetFileByHash(ctx context.Context, hash string) (File, error)
|
||||
GetOrganizationByID(ctx context.Context, id string) (Organization, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)
|
||||
GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error)
|
||||
GetProjectByID(ctx context.Context, id uuid.UUID) (Project, error)
|
||||
GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error)
|
||||
GetProjectVersionByID(ctx context.Context, id uuid.UUID) (ProjectVersion, error)
|
||||
GetProjectVersionByJobID(ctx context.Context, jobID uuid.UUID) (ProjectVersion, error)
|
||||
GetProjectVersionByProjectIDAndName(ctx context.Context, arg GetProjectVersionByProjectIDAndNameParams) (ProjectVersion, error)
|
||||
GetProjectVersionsByProjectID(ctx context.Context, dollar_1 uuid.UUID) ([]ProjectVersion, error)
|
||||
GetProjectsByIDs(ctx context.Context, ids []uuid.UUID) ([]Project, error)
|
||||
GetProjectsByOrganization(ctx context.Context, arg GetProjectsByOrganizationParams) ([]Project, error)
|
||||
GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) (ProvisionerDaemon, error)
|
||||
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
|
||||
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
|
||||
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
|
||||
GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id string) (User, error)
|
||||
GetUserCount(ctx context.Context) (int64, error)
|
||||
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error)
|
||||
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByWorkspaceIDAndName(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDAndNameParams) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByWorkspaceIDWithoutAfter(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildsByWorkspaceIDsWithoutAfter(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
|
||||
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
|
||||
GetWorkspaceByUserIDAndName(ctx context.Context, arg GetWorkspaceByUserIDAndNameParams) (Workspace, error)
|
||||
GetWorkspaceOwnerCountsByProjectIDs(ctx context.Context, ids []uuid.UUID) ([]GetWorkspaceOwnerCountsByProjectIDsRow, error)
|
||||
GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error)
|
||||
GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]WorkspaceResource, error)
|
||||
GetWorkspacesByProjectID(ctx context.Context, arg GetWorkspacesByProjectIDParams) ([]Workspace, error)
|
||||
GetWorkspacesByUserID(ctx context.Context, arg GetWorkspacesByUserIDParams) ([]Workspace, error)
|
||||
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
|
||||
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
|
||||
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
|
||||
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
|
||||
InsertParameterSchema(ctx context.Context, arg InsertParameterSchemaParams) (ParameterSchema, error)
|
||||
InsertParameterValue(ctx context.Context, arg InsertParameterValueParams) (ParameterValue, error)
|
||||
InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error)
|
||||
InsertProjectVersion(ctx context.Context, arg InsertProjectVersionParams) (ProjectVersion, error)
|
||||
InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)
|
||||
InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error)
|
||||
InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error)
|
||||
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
||||
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
|
||||
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)
|
||||
InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error)
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateProjectActiveVersionByID(ctx context.Context, arg UpdateProjectActiveVersionByIDParams) error
|
||||
UpdateProjectDeletedByID(ctx context.Context, arg UpdateProjectDeletedByIDParams) error
|
||||
UpdateProjectVersionByID(ctx context.Context, arg UpdateProjectVersionByIDParams) error
|
||||
UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error
|
||||
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
|
||||
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
|
||||
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
|
||||
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
|
||||
UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error
|
||||
UpdateWorkspaceDeletedByID(ctx context.Context, arg UpdateWorkspaceDeletedByIDParams) error
|
||||
}
|
||||
|
||||
var _ querier = (*sqlQuerier)(nil)
|
791
coderd/database/query.sql
Normal file
791
coderd/database/query.sql
Normal file
@ -0,0 +1,791 @@
|
||||
-- Database queries are generated using sqlc. See:
|
||||
-- https://docs.sqlc.dev/en/latest/tutorials/getting-started-postgresql.html
|
||||
--
|
||||
-- Run "make gen" to generate models and query functions.
|
||||
;
|
||||
|
||||
-- Acquires the lock for a single job that isn't started, completed,
|
||||
-- canceled, and that matches an array of provisioner types.
|
||||
--
|
||||
-- SKIP LOCKED is used to jump over locked rows. This prevents
|
||||
-- multiple provisioners from acquiring the same jobs. See:
|
||||
-- https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
|
||||
-- name: AcquireProvisionerJob :one
|
||||
UPDATE
|
||||
provisioner_jobs
|
||||
SET
|
||||
started_at = @started_at,
|
||||
updated_at = @started_at,
|
||||
worker_id = @worker_id
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
provisioner_jobs AS nested
|
||||
WHERE
|
||||
nested.started_at IS NULL
|
||||
AND nested.canceled_at IS NULL
|
||||
AND nested.completed_at IS NULL
|
||||
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
|
||||
ORDER BY
|
||||
nested.created_at FOR
|
||||
UPDATE
|
||||
SKIP LOCKED
|
||||
LIMIT
|
||||
1
|
||||
) RETURNING *;
|
||||
|
||||
-- name: DeleteParameterValueByID :exec
|
||||
DELETE FROM
|
||||
parameter_values
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetAPIKeyByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
api_keys
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetFileByHash :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
files
|
||||
WHERE
|
||||
hash = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetUserByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetUserByEmailOrUsername :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
LOWER(username) = LOWER(@username)
|
||||
OR email = @email
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users;
|
||||
|
||||
-- name: GetOrganizationByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetOrganizationByName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
LOWER(name) = LOWER(@name)
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetOrganizationsByUserID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organizations
|
||||
WHERE
|
||||
id = (
|
||||
SELECT
|
||||
organization_id
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
user_id = $1
|
||||
);
|
||||
|
||||
-- name: GetOrganizationMemberByUserID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND user_id = $2
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetParameterValuesByScope :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
parameter_values
|
||||
WHERE
|
||||
scope = $1
|
||||
AND scope_id = $2;
|
||||
|
||||
-- name: GetParameterValueByScopeAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
parameter_values
|
||||
WHERE
|
||||
scope = $1
|
||||
AND scope_id = $2
|
||||
AND name = $3
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetProjectByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
projects
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetProjectsByIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
projects
|
||||
WHERE
|
||||
id = ANY(@ids :: uuid [ ]);
|
||||
|
||||
-- name: GetProjectByOrganizationAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
projects
|
||||
WHERE
|
||||
organization_id = @organization_id
|
||||
AND deleted = @deleted
|
||||
AND LOWER(name) = LOWER(@name)
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetProjectsByOrganization :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
projects
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND deleted = $2;
|
||||
|
||||
-- name: GetParameterSchemasByJobID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
parameter_schemas
|
||||
WHERE
|
||||
job_id = $1;
|
||||
|
||||
-- name: GetProjectVersionsByProjectID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project_versions
|
||||
WHERE
|
||||
project_id = $1 :: uuid;
|
||||
|
||||
-- name: GetProjectVersionByJobID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project_versions
|
||||
WHERE
|
||||
job_id = $1;
|
||||
|
||||
-- name: GetProjectVersionByProjectIDAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project_versions
|
||||
WHERE
|
||||
project_id = $1
|
||||
AND name = $2;
|
||||
|
||||
-- name: GetProjectVersionByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
project_versions
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetProvisionerLogsByIDBetween :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_job_logs
|
||||
WHERE
|
||||
job_id = @job_id
|
||||
AND (
|
||||
created_at >= @created_after
|
||||
OR created_at <= @created_before
|
||||
)
|
||||
ORDER BY
|
||||
created_at;
|
||||
|
||||
-- name: GetProvisionerDaemonByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_daemons
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetProvisionerDaemons :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_daemons;
|
||||
|
||||
-- name: GetWorkspaceAgentByAuthToken :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
auth_token = $1
|
||||
ORDER BY
|
||||
created_at DESC;
|
||||
|
||||
-- name: GetWorkspaceAgentByInstanceID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
auth_instance_id = @auth_instance_id :: text
|
||||
ORDER BY
|
||||
created_at DESC;
|
||||
|
||||
-- name: GetProvisionerJobByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetProvisionerJobsByIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_jobs
|
||||
WHERE
|
||||
id = ANY(@ids :: uuid [ ]);
|
||||
|
||||
-- name: GetWorkspaceByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetWorkspacesByProjectID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
project_id = $1
|
||||
AND deleted = $2;
|
||||
|
||||
-- name: GetWorkspacesByUserID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
owner_id = $1
|
||||
AND deleted = $2;
|
||||
|
||||
-- name: GetWorkspaceByUserIDAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
owner_id = @owner_id
|
||||
AND deleted = @deleted
|
||||
AND LOWER(name) = LOWER(@name);
|
||||
|
||||
-- name: GetWorkspaceOwnerCountsByProjectIDs :many
|
||||
SELECT
|
||||
project_id,
|
||||
COUNT(DISTINCT owner_id)
|
||||
FROM
|
||||
workspaces
|
||||
WHERE
|
||||
project_id = ANY(@ids :: uuid [ ])
|
||||
GROUP BY
|
||||
project_id,
|
||||
owner_id;
|
||||
|
||||
-- name: GetWorkspaceBuildByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetWorkspaceBuildByJobID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
job_id = $1
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetWorkspaceBuildByWorkspaceIDAndName :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_id = $1
|
||||
AND name = $2;
|
||||
|
||||
-- name: GetWorkspaceBuildByWorkspaceID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_id = $1;
|
||||
|
||||
-- name: GetWorkspaceBuildByWorkspaceIDWithoutAfter :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_id = $1
|
||||
AND after_id IS NULL
|
||||
LIMIT
|
||||
1;
|
||||
|
||||
-- name: GetWorkspaceBuildsByWorkspaceIDsWithoutAfter :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_builds
|
||||
WHERE
|
||||
workspace_id = ANY(@ids :: uuid [ ])
|
||||
AND after_id IS NULL;
|
||||
|
||||
-- name: GetWorkspaceResourceByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceResourcesByJobID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
job_id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentByResourceID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
resource_id = $1;
|
||||
|
||||
-- name: InsertAPIKey :one
|
||||
INSERT INTO
|
||||
api_keys (
|
||||
id,
|
||||
hashed_secret,
|
||||
user_id,
|
||||
application,
|
||||
name,
|
||||
last_used,
|
||||
expires_at,
|
||||
created_at,
|
||||
updated_at,
|
||||
login_type,
|
||||
oidc_access_token,
|
||||
oidc_refresh_token,
|
||||
oidc_id_token,
|
||||
oidc_expiry,
|
||||
devurl_token
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15
|
||||
) RETURNING *;
|
||||
|
||||
-- name: InsertFile :one
|
||||
INSERT INTO
|
||||
files (hash, created_at, created_by, mimetype, data)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertProvisionerJobLogs :many
|
||||
INSERT INTO
|
||||
provisioner_job_logs
|
||||
SELECT
|
||||
unnest(@id :: uuid [ ]) AS id,
|
||||
@job_id :: uuid AS job_id,
|
||||
unnest(@created_at :: timestamptz [ ]) AS created_at,
|
||||
unnest(@source :: log_source [ ]) as source,
|
||||
unnest(@level :: log_level [ ]) as level,
|
||||
unnest(@output :: varchar(1024) [ ]) as output RETURNING *;
|
||||
|
||||
-- name: InsertOrganization :one
|
||||
INSERT INTO
|
||||
organizations (id, name, description, created_at, updated_at)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertOrganizationMember :one
|
||||
INSERT INTO
|
||||
organization_members (
|
||||
organization_id,
|
||||
user_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
roles
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertParameterValue :one
|
||||
INSERT INTO
|
||||
parameter_values (
|
||||
id,
|
||||
name,
|
||||
created_at,
|
||||
updated_at,
|
||||
scope,
|
||||
scope_id,
|
||||
source_scheme,
|
||||
source_value,
|
||||
destination_scheme
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
|
||||
|
||||
-- name: InsertProject :one
|
||||
INSERT INTO
|
||||
projects (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
organization_id,
|
||||
name,
|
||||
provisioner,
|
||||
active_version_id
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
|
||||
|
||||
-- name: InsertWorkspaceResource :one
|
||||
INSERT INTO
|
||||
workspace_resources (
|
||||
id,
|
||||
created_at,
|
||||
job_id,
|
||||
transition,
|
||||
address,
|
||||
type,
|
||||
name,
|
||||
agent_id
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
|
||||
|
||||
-- name: InsertProjectVersion :one
|
||||
INSERT INTO
|
||||
project_versions (
|
||||
id,
|
||||
project_id,
|
||||
organization_id,
|
||||
created_at,
|
||||
updated_at,
|
||||
name,
|
||||
description,
|
||||
job_id
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
|
||||
|
||||
-- name: InsertParameterSchema :one
|
||||
INSERT INTO
|
||||
parameter_schemas (
|
||||
id,
|
||||
created_at,
|
||||
job_id,
|
||||
name,
|
||||
description,
|
||||
default_source_scheme,
|
||||
default_source_value,
|
||||
allow_override_source,
|
||||
default_destination_scheme,
|
||||
allow_override_destination,
|
||||
default_refresh,
|
||||
redisplay_value,
|
||||
validation_error,
|
||||
validation_condition,
|
||||
validation_type_system,
|
||||
validation_value_type
|
||||
)
|
||||
VALUES
|
||||
(
|
||||
$1,
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
$9,
|
||||
$10,
|
||||
$11,
|
||||
$12,
|
||||
$13,
|
||||
$14,
|
||||
$15,
|
||||
$16
|
||||
) RETURNING *;
|
||||
|
||||
-- name: InsertProvisionerDaemon :one
|
||||
INSERT INTO
|
||||
provisioner_daemons (id, created_at, organization_id, name, provisioners)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
-- name: InsertProvisionerJob :one
|
||||
INSERT INTO
|
||||
provisioner_jobs (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
organization_id,
|
||||
initiator_id,
|
||||
provisioner,
|
||||
storage_method,
|
||||
storage_source,
|
||||
type,
|
||||
input
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
|
||||
|
||||
-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
users (
|
||||
id,
|
||||
email,
|
||||
name,
|
||||
login_type,
|
||||
revoked,
|
||||
hashed_password,
|
||||
created_at,
|
||||
updated_at,
|
||||
username
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, false, $5, $6, $7, $8) RETURNING *;
|
||||
|
||||
-- name: InsertWorkspace :one
|
||||
INSERT INTO
|
||||
workspaces (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
owner_id,
|
||||
project_id,
|
||||
name
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6) RETURNING *;
|
||||
|
||||
-- name: InsertWorkspaceAgent :one
|
||||
INSERT INTO
|
||||
workspace_agents (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
resource_id,
|
||||
auth_token,
|
||||
auth_instance_id,
|
||||
environment_variables,
|
||||
startup_script,
|
||||
instance_metadata,
|
||||
resource_metadata
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
|
||||
|
||||
-- name: InsertWorkspaceBuild :one
|
||||
INSERT INTO
|
||||
workspace_builds (
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
workspace_id,
|
||||
project_version_id,
|
||||
before_id,
|
||||
name,
|
||||
transition,
|
||||
initiator,
|
||||
job_id,
|
||||
provisioner_state
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *;
|
||||
|
||||
-- name: UpdateAPIKeyByID :exec
|
||||
UPDATE
|
||||
api_keys
|
||||
SET
|
||||
last_used = $2,
|
||||
expires_at = $3,
|
||||
oidc_access_token = $4,
|
||||
oidc_refresh_token = $5,
|
||||
oidc_expiry = $6
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProjectActiveVersionByID :exec
|
||||
UPDATE
|
||||
projects
|
||||
SET
|
||||
active_version_id = $2
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProjectDeletedByID :exec
|
||||
UPDATE
|
||||
projects
|
||||
SET
|
||||
deleted = $2
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProjectVersionByID :exec
|
||||
UPDATE
|
||||
project_versions
|
||||
SET
|
||||
project_id = $2,
|
||||
updated_at = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProvisionerDaemonByID :exec
|
||||
UPDATE
|
||||
provisioner_daemons
|
||||
SET
|
||||
updated_at = $2,
|
||||
provisioners = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProvisionerJobByID :exec
|
||||
UPDATE
|
||||
provisioner_jobs
|
||||
SET
|
||||
updated_at = $2
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProvisionerJobWithCancelByID :exec
|
||||
UPDATE
|
||||
provisioner_jobs
|
||||
SET
|
||||
canceled_at = $2
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateProvisionerJobWithCompleteByID :exec
|
||||
UPDATE
|
||||
provisioner_jobs
|
||||
SET
|
||||
updated_at = $2,
|
||||
completed_at = $3,
|
||||
canceled_at = $4,
|
||||
error = $5
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceDeletedByID :exec
|
||||
UPDATE
|
||||
workspaces
|
||||
SET
|
||||
deleted = $2
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceAgentConnectionByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
first_connected_at = $2,
|
||||
last_connected_at = $3,
|
||||
disconnected_at = $4
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceBuildByID :exec
|
||||
UPDATE
|
||||
workspace_builds
|
||||
SET
|
||||
updated_at = $2,
|
||||
after_id = $3,
|
||||
provisioner_state = $4
|
||||
WHERE
|
||||
id = $1;
|
2718
coderd/database/query.sql.go
Normal file
2718
coderd/database/query.sql.go
Normal file
File diff suppressed because it is too large
Load Diff
29
coderd/database/sqlc.yaml
Normal file
29
coderd/database/sqlc.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
# sqlc is used to generate types from sql schema language.
# It was chosen to ensure type-safety when interacting with
# the database.
version: "1"
packages:
  - name: "database"
    path: "."
    queries: "./query.sql"
    schema: "./dump.sql"
    engine: "postgresql"
    # Emit a Querier interface so the store can be faked (databasefake).
    emit_interface: true
    emit_json_tags: true
    emit_db_tags: true
    # We replace the generated db file with our own
    # to add support for transactions. This file is
    # deleted after generation.
    output_db_file_name: db_tmp.go
# NOTE(review): indentation was lost in this rendering — overrides/rename are
# placed at the top level here, which sqlc v1 supports globally; confirm
# against the repository whether they were nested under the package entry.
overrides:
  # citext is a case-insensitive text column; map it to a plain Go string.
  - db_type: citext
    go_type: string
# Rename generated identifiers to follow Go initialism conventions
# (APIKey, OIDC...) instead of sqlc's default casing.
rename:
  api_key: APIKey
  login_type_oidc: LoginTypeOIDC
  oidc_access_token: OIDCAccessToken
  oidc_expiry: OIDCExpiry
  oidc_id_token: OIDCIDToken
  oidc_refresh_token: OIDCRefreshToken
  parameter_type_system_hcl: ParameterTypeSystemHCL
  userstatus: UserStatus
|
8
coderd/database/time.go
Normal file
8
coderd/database/time.go
Normal file
@ -0,0 +1,8 @@
|
||||
package database
|
||||
|
||||
import "time"
|
||||
|
||||
// Now returns a standardized timezone used for database resources.
// Normalizing every stored timestamp to UTC keeps comparisons consistent
// regardless of the local timezone coderd runs in.
func Now() time.Time {
	current := time.Now()
	return current.In(time.UTC)
}
|
@ -12,10 +12,10 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func (api *api) postFile(rw http.ResponseWriter, r *http.Request) {
|
||||
|
117
coderd/httpapi/httpapi.go
Normal file
117
coderd/httpapi/httpapi.go
Normal file
@ -0,0 +1,117 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
	// validate is the shared validator instance configured by init below.
	validate *validator.Validate
	// usernameRegex accepts alphanumeric segments separated by single
	// hyphens; no leading/trailing hyphen and no consecutive hyphens.
	usernameRegex = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$")
)

// This init is used to create a validator and register validation-specific
// functionality for the HTTP API.
//
// A single validator instance is used, because it caches struct parsing.
func init() {
	validate = validator.New()
	// Report the field's JSON tag name in validation errors so API clients
	// see the wire name rather than the Go struct field name.
	validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
		name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
		if name == "-" {
			return ""
		}
		return name
	})
	// Custom "username" rule: a string of 1-32 characters matching
	// usernameRegex above.
	err := validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
		f := fl.Field().Interface()
		str, ok := f.(string)
		if !ok {
			return false
		}
		if len(str) > 32 {
			return false
		}
		if len(str) < 1 {
			return false
		}
		return usernameRegex.MatchString(str)
	})
	if err != nil {
		// Registration only fails on programmer error (e.g. empty tag name),
		// so panicking at init is appropriate.
		panic(err)
	}
}
|
||||
|
||||
// Response represents a generic HTTP response.
type Response struct {
	Message string `json:"message" validate:"required"`
	Errors  []Error `json:"errors,omitempty" validate:"required"`
}

// Error represents a scoped error to a user input.
type Error struct {
	Field string `json:"field" validate:"required"`
	Code  string `json:"code" validate:"required"`
}

// Write outputs a standardized format to an HTTP response body.
// The response is encoded to a buffer first so that an encoding failure
// can still produce a clean 500 instead of a half-written body.
func Write(rw http.ResponseWriter, status int, response Response) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(true)
	if err := enc.Encode(response); err != nil {
		http.Error(rw, err.Error(), http.StatusInternalServerError)
		return
	}
	rw.Header().Set("Content-Type", "application/json; charset=utf-8")
	rw.WriteHeader(status)
	if _, err := rw.Write(buf.Bytes()); err != nil {
		http.Error(rw, err.Error(), http.StatusInternalServerError)
		return
	}
}
|
||||
|
||||
// Read decodes JSON from the HTTP request into the value provided.
|
||||
// It uses go-validator to validate the incoming request body.
|
||||
func Read(rw http.ResponseWriter, r *http.Request, value interface{}) bool {
|
||||
err := json.NewDecoder(r.Body).Decode(value)
|
||||
if err != nil {
|
||||
Write(rw, http.StatusBadRequest, Response{
|
||||
Message: fmt.Sprintf("read body: %s", err.Error()),
|
||||
})
|
||||
return false
|
||||
}
|
||||
err = validate.Struct(value)
|
||||
var validationErrors validator.ValidationErrors
|
||||
if errors.As(err, &validationErrors) {
|
||||
apiErrors := make([]Error, 0, len(validationErrors))
|
||||
for _, validationError := range validationErrors {
|
||||
apiErrors = append(apiErrors, Error{
|
||||
Field: validationError.Field(),
|
||||
Code: validationError.Tag(),
|
||||
})
|
||||
}
|
||||
Write(rw, http.StatusBadRequest, Response{
|
||||
Message: "Validation failed",
|
||||
Errors: apiErrors,
|
||||
})
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
Write(rw, http.StatusInternalServerError, Response{
|
||||
Message: fmt.Sprintf("validation: %s", err.Error()),
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
144
coderd/httpapi/httpapi_test.go
Normal file
144
coderd/httpapi/httpapi_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
package httpapi_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NoErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := httptest.NewRecorder()
|
||||
httpapi.Write(rw, http.StatusOK, httpapi.Response{
|
||||
Message: "wow",
|
||||
})
|
||||
var m map[string]interface{}
|
||||
err := json.NewDecoder(rw.Body).Decode(&m)
|
||||
require.NoError(t, err)
|
||||
_, ok := m["errors"]
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
	t.Parallel()
	// An empty JSON object into an empty struct has nothing to decode or
	// validate, so Read must succeed.
	t.Run("EmptyStruct", func(t *testing.T) {
		t.Parallel()
		rw := httptest.NewRecorder()
		r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
		v := struct{}{}
		require.True(t, httpapi.Read(rw, r, &v))
	})

	// With no request body, JSON decoding fails and Read reports false.
	t.Run("NoBody", func(t *testing.T) {
		t.Parallel()
		rw := httptest.NewRecorder()
		r := httptest.NewRequest("POST", "/", nil)
		var v json.RawMessage
		require.False(t, httpapi.Read(rw, r, v))
	})

	// A body satisfying the `validate:"required"` tag decodes and passes.
	t.Run("Validate", func(t *testing.T) {
		t.Parallel()
		type toValidate struct {
			Value string `json:"value" validate:"required"`
		}
		rw := httptest.NewRecorder()
		r := httptest.NewRequest("POST", "/", bytes.NewBufferString(`{"value":"hi"}`))

		var validate toValidate
		require.True(t, httpapi.Read(rw, r, &validate))
		require.Equal(t, validate.Value, "hi")
	})

	// A missing required field must produce a single scoped error whose
	// Field is the JSON tag name and whose Code is the validator tag.
	t.Run("ValidateFailure", func(t *testing.T) {
		t.Parallel()
		type toValidate struct {
			Value string `json:"value" validate:"required"`
		}
		rw := httptest.NewRecorder()
		r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))

		var validate toValidate
		require.False(t, httpapi.Read(rw, r, &validate))
		var v httpapi.Response
		err := json.NewDecoder(rw.Body).Decode(&v)
		require.NoError(t, err)
		require.Len(t, v.Errors, 1)
		require.Equal(t, v.Errors[0].Field, "value")
		require.Equal(t, v.Errors[0].Code, "required")
	})
}
|
||||
|
||||
// TestReadUsername exercises the custom "username" validator registered by
// the httpapi package: 1-32 alphanumeric characters with single interior
// hyphens, no whitespace, underscores, or consecutive hyphens.
func TestReadUsername(t *testing.T) {
	t.Parallel()
	// Tests whether usernames are valid or not.
	testCases := []struct {
		Username string
		Valid    bool
	}{
		// Valid: alphanumerics and single hyphen separators, 1-32 chars.
		{"1", true},
		{"12", true},
		{"123", true},
		{"12345678901234567890", true},
		{"123456789012345678901", true},
		{"a", true},
		{"a1", true},
		{"a1b2", true},
		{"a1b2c3d4e5f6g7h8i9j0", true},
		{"a1b2c3d4e5f6g7h8i9j0k", true},
		{"aa", true},
		{"abc", true},
		{"abcdefghijklmnopqrst", true},
		{"abcdefghijklmnopqrstu", true},
		{"wow-test", true},

		// Invalid: empty, whitespace anywhere, underscores, double hyphens.
		{"", false},
		{" ", false},
		{" a", false},
		{" a ", false},
		{" 1", false},
		{"1 ", false},
		{" aa", false},
		{"aa ", false},
		{" 12", false},
		{"12 ", false},
		{" a1", false},
		{"a1 ", false},
		{" abcdefghijklmnopqrstu", false},
		{"abcdefghijklmnopqrstu ", false},
		{" 123456789012345678901", false},
		{" a1b2c3d4e5f6g7h8i9j0k", false},
		{"a1b2c3d4e5f6g7h8i9j0k ", false},
		{"bananas_wow", false},
		{"test--now", false},

		// Invalid: longer than the 32-character limit.
		{"123456789012345678901234567890123", false},
		{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
		{"123456789012345678901234567890123123456789012345678901234567890123", false},
	}
	type toValidate struct {
		Username string `json:"username" validate:"username"`
	}
	for _, testCase := range testCases {
		testCase := testCase // capture for the parallel subtest closure
		t.Run(testCase.Username, func(t *testing.T) {
			t.Parallel()
			rw := httptest.NewRecorder()
			data, err := json.Marshal(toValidate{testCase.Username})
			require.NoError(t, err)
			r := httptest.NewRequest("POST", "/", bytes.NewBuffer(data))

			var validate toValidate
			// Read returns true exactly when validation passes.
			require.Equal(t, httpapi.Read(rw, r, &validate), testCase.Valid)
		})
	}
}
|
167
coderd/httpmw/apikey.go
Normal file
167
coderd/httpmw/apikey.go
Normal file
@ -0,0 +1,167 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
// AuthCookie represents the name of the cookie the API key is stored in.
const AuthCookie = "session_token"

// OAuth2Config contains a subset of functions exposed from oauth2.Config.
// It is abstracted for simple testing.
type OAuth2Config interface {
	TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}

// apiKeyContextKey is an unexported context key type so no other package
// can read or overwrite the API key stored on the request context.
type apiKeyContextKey struct{}

// APIKey returns the API key from the ExtractAPIKey handler.
// It panics when the middleware was not installed, since calling it
// without ExtractAPIKey in the chain is a programmer error.
func APIKey(r *http.Request) database.APIKey {
	apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
	if !ok {
		panic("developer error: apikey middleware not provided")
	}
	return apiKey
}
|
||||
|
||||
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
//
// On success the validated database.APIKey is stored on the request
// context for retrieval via APIKey; on any failure a JSON error response
// is written and the next handler is never invoked.
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			cookie, err := r.Cookie(AuthCookie)
			if err != nil {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: fmt.Sprintf("%q cookie must be provided", AuthCookie),
				})
				return
			}
			parts := strings.Split(cookie.Value, "-")
			// APIKeys are formatted: ID-SECRET
			if len(parts) != 2 {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: fmt.Sprintf("invalid %q cookie api key format", AuthCookie),
				})
				return
			}
			keyID := parts[0]
			keySecret := parts[1]
			// Ensuring key lengths are valid.
			// Length checks (10-char ID, 22-char secret) reject malformed
			// tokens before touching the database.
			if len(keyID) != 10 {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie),
				})
				return
			}
			if len(keySecret) != 22 {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie),
				})
				return
			}
			key, err := db.GetAPIKeyByID(r.Context(), keyID)
			if err != nil {
				if errors.Is(err, sql.ErrNoRows) {
					httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
						Message: "api key is invalid",
					})
					return
				}
				httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
					Message: fmt.Sprintf("get api key by id: %s", err.Error()),
				})
				return
			}
			// Only the SHA-256 hash of the secret is stored, so hash the
			// presented secret before comparing.
			hashed := sha256.Sum256([]byte(keySecret))

			// Checking to see if the secret is valid.
			// Constant-time comparison avoids leaking secret bytes through
			// timing differences.
			if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: "api key secret is invalid",
				})
				return
			}
			now := database.Now()
			// Tracks if the API key has properties updated!
			changed := false

			if key.LoginType == database.LoginTypeOIDC {
				// Check if the OIDC token is expired!
				// A zero OIDCExpiry means no expiry was recorded, so skip.
				if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
					// If it is, let's refresh it from the provided config!
					token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
						AccessToken:  key.OIDCAccessToken,
						RefreshToken: key.OIDCRefreshToken,
						Expiry:       key.OIDCExpiry,
					}).Token()
					if err != nil {
						httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
							Message: fmt.Sprintf("couldn't refresh expired oauth token: %s", err.Error()),
						})
						return
					}
					key.OIDCAccessToken = token.AccessToken
					key.OIDCRefreshToken = token.RefreshToken
					key.OIDCExpiry = token.Expiry
					// The API key lives no longer than the refreshed OIDC token.
					key.ExpiresAt = token.Expiry
					changed = true
				}
			}

			// Checking if the key is expired.
			if key.ExpiresAt.Before(now) {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: fmt.Sprintf("api key expired at %q", key.ExpiresAt.String()),
				})
				return
			}

			// Only update LastUsed once an hour to prevent database spam.
			if now.Sub(key.LastUsed) > time.Hour {
				key.LastUsed = now
				changed = true
			}
			// Only update the ExpiresAt once an hour to prevent database spam.
			// We extend the ExpiresAt to reduce reauthentication.
			apiKeyLifetime := 24 * time.Hour
			if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
				key.ExpiresAt = now.Add(apiKeyLifetime)
				changed = true
			}

			if changed {
				err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
					ID:               key.ID,
					ExpiresAt:        key.ExpiresAt,
					LastUsed:         key.LastUsed,
					OIDCAccessToken:  key.OIDCAccessToken,
					OIDCRefreshToken: key.OIDCRefreshToken,
					OIDCExpiry:       key.OIDCExpiry,
				})
				if err != nil {
					httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
						Message: fmt.Sprintf("api key couldn't update: %s", err.Error()),
					})
					return
				}
			}

			ctx := context.WithValue(r.Context(), apiKeyContextKey{}, key)
			next.ServeHTTP(rw, r.WithContext(ctx))
		})
	}
}
|
375
coderd/httpmw/apikey_test.go
Normal file
375
coderd/httpmw/apikey_test.go
Normal file
@ -0,0 +1,375 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func randomAPIKeyParts() (id string, secret string) {
|
||||
id, _ = cryptorand.String(10)
|
||||
secret, _ = cryptorand.String(22)
|
||||
return id, secret
|
||||
}
|
||||
|
||||
// TestAPIKey drives the ExtractAPIKey middleware against a fake database,
// covering cookie parsing, secret verification, expiry, and the
// LastUsed/ExpiresAt/OIDC refresh bookkeeping.
func TestAPIKey(t *testing.T) {
	t.Parallel()

	successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		// Only called if the API key passes through the handler.
		httpapi.Write(rw, http.StatusOK, httpapi.Response{
			Message: "it worked!",
		})
	})

	// No session_token cookie at all -> 401.
	t.Run("NoCookie", func(t *testing.T) {
		t.Parallel()
		var (
			db = databasefake.New()
			r  = httptest.NewRequest("GET", "/", nil)
			rw = httptest.NewRecorder()
		)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// More than one "-" separator is not the ID-SECRET format -> 401.
	t.Run("InvalidFormat", func(t *testing.T) {
		t.Parallel()
		var (
			db = databasefake.New()
			r  = httptest.NewRequest("GET", "/", nil)
			rw = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: "test-wow-hello",
		})

		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// ID shorter than the required 10 characters -> 401.
	t.Run("InvalidIDLength", func(t *testing.T) {
		t.Parallel()
		var (
			db = databasefake.New()
			r  = httptest.NewRequest("GET", "/", nil)
			rw = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: "test-wow",
		})

		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// Secret shorter than the required 22 characters -> 401.
	t.Run("InvalidSecretLength", func(t *testing.T) {
		t.Parallel()
		var (
			db = databasefake.New()
			r  = httptest.NewRequest("GET", "/", nil)
			rw = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: "testtestid-wow",
		})

		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// Well-formed key that was never inserted into the database -> 401.
	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// Stored hash doesn't match the presented secret -> 401.
	t.Run("InvalidSecret", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		// Use a different secret so they don't match!
		hashed := sha256.Sum256([]byte("differentsecret"))
		_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// Key inserted with a zero ExpiresAt is already expired -> 401.
	t.Run("Expired", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	// Fresh key (recently used, a day until expiry) passes through and is
	// NOT rewritten in the database.
	t.Run("Valid", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().AddDate(0, 0, 1),
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			// Checks that it exists on the context!
			_ = httpmw.APIKey(r)
			httpapi.Write(rw, http.StatusOK, httpapi.Response{
				Message: "it worked!",
			})
		})).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)

		gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
		require.NoError(t, err)

		require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
		require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
	})

	// LastUsed older than an hour triggers a LastUsed refresh only.
	t.Run("ValidUpdateLastUsed", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
			LastUsed:     database.Now().AddDate(0, 0, -1),
			ExpiresAt:    database.Now().AddDate(0, 0, 1),
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)

		gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
		require.NoError(t, err)

		require.NotEqual(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
		require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
	})

	// Expiry within the refresh window (one minute left) is extended.
	t.Run("ValidUpdateExpiry", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().Add(time.Minute),
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)

		gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
		require.NoError(t, err)

		require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
		require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
	})

	// An OIDC key with a zero OIDCExpiry skips the refresh path entirely.
	t.Run("OIDCNotExpired", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
			LoginType:    database.LoginTypeOIDC,
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().AddDate(0, 0, 1),
		})
		require.NoError(t, err)
		httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)

		gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
		require.NoError(t, err)

		require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
		require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
	})

	// An expired OIDC token is refreshed through the stubbed OAuth2Config
	// and the refreshed token's fields are persisted on the key.
	t.Run("OIDCRefresh", func(t *testing.T) {
		t.Parallel()
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			HashedSecret: hashed[:],
			LoginType:    database.LoginTypeOIDC,
			LastUsed:     database.Now(),
			OIDCExpiry:   database.Now().AddDate(0, 0, -1),
		})
		require.NoError(t, err)
		token := &oauth2.Token{
			AccessToken:  "wow",
			RefreshToken: "moo",
			Expiry:       database.Now().AddDate(0, 0, 1),
		}
		httpmw.ExtractAPIKey(db, &oauth2Config{
			tokenSource: &oauth2TokenSource{
				token: func() (*oauth2.Token, error) {
					return token, nil
				},
			},
		})(successHandler).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)

		gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
		require.NoError(t, err)

		require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
		require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
		require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
	})
}
|
||||
|
||||
// oauth2Config is a test double satisfying httpmw.OAuth2Config; it returns
// a canned token source instead of talking to a real OAuth2 provider.
type oauth2Config struct {
	tokenSource *oauth2TokenSource
}

// TokenSource ignores its arguments and hands back the stubbed source.
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
	return o.tokenSource
}

// oauth2TokenSource is a test double for oauth2.TokenSource whose Token
// behavior is supplied per-test via the token func.
type oauth2TokenSource struct {
	token func() (*oauth2.Token, error)
}

func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
	return o.token()
}
|
30
coderd/httpmw/httpmw.go
Normal file
30
coderd/httpmw/httpmw.go
Normal file
@ -0,0 +1,30 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
// parseUUID consumes a url parameter and parses it as a UUID.
|
||||
func parseUUID(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
|
||||
rawID := chi.URLParam(r, param)
|
||||
if rawID == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q must be provided", param),
|
||||
})
|
||||
return uuid.UUID{}, false
|
||||
}
|
||||
parsed, err := uuid.Parse(rawID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q must be a uuid", param),
|
||||
})
|
||||
return uuid.UUID{}, false
|
||||
}
|
||||
return parsed, true
|
||||
}
|
86
coderd/httpmw/organizationparam.go
Normal file
86
coderd/httpmw/organizationparam.go
Normal file
@ -0,0 +1,86 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
// Unexported context key types prevent other packages from reading or
// overwriting the values this middleware stores on the request context.
type organizationParamContextKey struct{}
type organizationMemberParamContextKey struct{}

// OrganizationParam returns the organization from the ExtractOrganizationParam handler.
// It panics when the middleware was not installed, which is a programmer error.
func OrganizationParam(r *http.Request) database.Organization {
	organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
	if !ok {
		panic("developer error: organization param middleware not provided")
	}
	return organization
}

// OrganizationMemberParam returns the organization membership that allowed the query
// from the ExtractOrganizationParam handler.
// It panics when the middleware was not installed, which is a programmer error.
func OrganizationMemberParam(r *http.Request) database.OrganizationMember {
	organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember)
	if !ok {
		panic("developer error: organization param middleware not provided")
	}
	return organizationMember
}
|
||||
|
||||
// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter.
// This middleware requires the API key middleware higher in the call stack for authentication.
//
// On success both the organization and the caller's membership are stored
// on the request context for OrganizationParam / OrganizationMemberParam.
func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			organizationID := chi.URLParam(r, "organization")
			if organizationID == "" {
				httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
					Message: "organization must be provided",
				})
				return
			}
			organization, err := db.GetOrganizationByID(r.Context(), organizationID)
			// sql.ErrNoRows is a client-addressable 404; anything else is a
			// server-side failure.
			if errors.Is(err, sql.ErrNoRows) {
				httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
					Message: fmt.Sprintf("organization %q does not exist", organizationID),
				})
				return
			}
			if err != nil {
				httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
					Message: fmt.Sprintf("get organization: %s", err.Error()),
				})
				return
			}
			// Requires ExtractAPIKey earlier in the chain; APIKey panics
			// otherwise.
			apiKey := APIKey(r)
			organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{
				OrganizationID: organization.ID,
				UserID:         apiKey.UserID,
			})
			// A missing membership row means the authenticated user simply
			// isn't part of this organization.
			if errors.Is(err, sql.ErrNoRows) {
				httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
					Message: "not a member of the organization",
				})
				return
			}
			if err != nil {
				httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
					Message: fmt.Sprintf("get organization member: %s", err.Error()),
				})
				return
			}

			ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
			ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
			next.ServeHTTP(rw, r.WithContext(ctx))
		})
	}
}
|
165
coderd/httpmw/organizationparam_test.go
Normal file
165
coderd/httpmw/organizationparam_test.go
Normal file
@ -0,0 +1,165 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// TestOrganizationParam exercises ExtractOrganizationParam: missing URL
// parameter, unknown organization, authenticated non-member, and success.
func TestOrganizationParam(t *testing.T) {
	t.Parallel()

	// setupAuthentication seeds a user plus a valid API key cookie so the
	// ExtractAPIKey middleware authenticates the request.
	setupAuthentication := func(db database.Store) (*http.Request, database.User) {
		var (
			id, secret = randomAPIKeyParts()
			r          = httptest.NewRequest("GET", "/", nil)
			hashed     = sha256.Sum256([]byte(secret))
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})
		userID, err := cryptorand.String(16)
		require.NoError(t, err)
		username, err := cryptorand.String(8)
		require.NoError(t, err)
		user, err := db.InsertUser(r.Context(), database.InsertUserParams{
			ID:             userID,
			Email:          "testaccount@coder.com",
			Name:           "example",
			LoginType:      database.LoginTypeBuiltIn,
			HashedPassword: hashed[:],
			Username:       username,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			UserID:       user.ID,
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().Add(time.Minute),
		})
		require.NoError(t, err)
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
		return r, user
	}

	t.Run("None", func(t *testing.T) {
		t.Parallel()
		var (
			db   = databasefake.New()
			rw   = httptest.NewRecorder()
			r, _ = setupAuthentication(db)
			rtr  = chi.NewRouter()
		)
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", nil)
		rtr.ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		// No "organization" URL parameter was added to the route context.
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		var (
			db   = databasefake.New()
			rw   = httptest.NewRecorder()
			r, _ = setupAuthentication(db)
			rtr  = chi.NewRouter()
		)
		// "nothin" does not match any organization ID in the fake database.
		chi.RouteContext(r.Context()).URLParams.Add("organization", "nothin")
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", nil)
		rtr.ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusNotFound, res.StatusCode)
	})

	t.Run("NotInOrganization", func(t *testing.T) {
		t.Parallel()
		var (
			db   = databasefake.New()
			rw   = httptest.NewRecorder()
			r, _ = setupAuthentication(db)
			rtr  = chi.NewRouter()
		)
		organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
			ID:        uuid.NewString(),
			Name:      "test",
			CreatedAt: database.Now(),
			UpdatedAt: database.Now(),
		})
		require.NoError(t, err)
		chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID)
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", nil)
		rtr.ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		// The organization exists but no membership row was inserted.
		require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	})

	t.Run("Success", func(t *testing.T) {
		t.Parallel()
		var (
			db      = databasefake.New()
			rw      = httptest.NewRecorder()
			r, user = setupAuthentication(db)
			rtr     = chi.NewRouter()
		)
		organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
			ID:        uuid.NewString(),
			Name:      "test",
			CreatedAt: database.Now(),
			UpdatedAt: database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
			OrganizationID: organization.ID,
			UserID:         user.ID,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)
		chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID)
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
			// Both accessors must succeed without panicking.
			_ = httpmw.OrganizationParam(r)
			_ = httpmw.OrganizationMemberParam(r)
			rw.WriteHeader(http.StatusOK)
		})
		rtr.ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)
	})
}
|
53
coderd/httpmw/projectparam.go
Normal file
53
coderd/httpmw/projectparam.go
Normal file
@ -0,0 +1,53 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type projectParamContextKey struct{}
|
||||
|
||||
// ProjectParam returns the project from the ExtractProjectParam handler.
|
||||
func ProjectParam(r *http.Request) database.Project {
|
||||
project, ok := r.Context().Value(projectParamContextKey{}).(database.Project)
|
||||
if !ok {
|
||||
panic("developer error: project param middleware not provided")
|
||||
}
|
||||
return project
|
||||
}
|
||||
|
||||
// ExtractProjectParam grabs a project from the "project" URL parameter.
|
||||
func ExtractProjectParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
projectID, parsed := parseUUID(rw, r, "project")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
project, err := db.GetProjectByID(r.Context(), projectID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("project %q does not exist", projectID),
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get project: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), projectParamContextKey{}, project)
|
||||
chi.RouteContext(ctx).URLParams.Add("organization", project.OrganizationID)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
159
coderd/httpmw/projectparam_test.go
Normal file
159
coderd/httpmw/projectparam_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// TestProjectParam exercises ExtractProjectParam: missing parameter, unknown
// project, malformed UUID, and the success path through the full middleware
// chain.
func TestProjectParam(t *testing.T) {
	t.Parallel()

	// setupAuthentication seeds a user, API key cookie, and an organization
	// (with membership) that projects can be attached to.
	setupAuthentication := func(db database.Store) (*http.Request, database.Organization) {
		var (
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
		)
		r := httptest.NewRequest("GET", "/", nil)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})
		userID, err := cryptorand.String(16)
		require.NoError(t, err)
		username, err := cryptorand.String(8)
		require.NoError(t, err)
		user, err := db.InsertUser(r.Context(), database.InsertUserParams{
			ID:             userID,
			Email:          "testaccount@coder.com",
			Name:           "example",
			LoginType:      database.LoginTypeBuiltIn,
			HashedPassword: hashed[:],
			Username:       username,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			UserID:       user.ID,
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().Add(time.Minute),
		})
		require.NoError(t, err)
		orgID, err := cryptorand.String(16)
		require.NoError(t, err)
		organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
			ID:          orgID,
			Name:        "banana",
			Description: "wowie",
			CreatedAt:   database.Now(),
			UpdatedAt:   database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
			OrganizationID: orgID,
			UserID:         user.ID,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)

		ctx := chi.NewRouteContext()
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
		return r, organization
	}

	t.Run("None", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		rtr.Use(httpmw.ExtractProjectParam(db))
		rtr.Get("/", nil)
		r, _ := setupAuthentication(db)
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		// No "project" URL parameter fails UUID parsing up front.
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		rtr.Use(httpmw.ExtractProjectParam(db))
		rtr.Get("/", nil)

		r, _ := setupAuthentication(db)
		// A valid UUID that matches no project row.
		chi.RouteContext(r.Context()).URLParams.Add("project", uuid.NewString())
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusNotFound, res.StatusCode)
	})

	t.Run("BadUUID", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		rtr.Use(httpmw.ExtractProjectParam(db))
		rtr.Get("/", nil)

		r, _ := setupAuthentication(db)
		chi.RouteContext(r.Context()).URLParams.Add("project", "not-a-uuid")
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("Project", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		// ExtractProjectParam injects the "organization" parameter consumed
		// by ExtractOrganizationParam below it.
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractProjectParam(db),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
			_ = httpmw.ProjectParam(r)
			rw.WriteHeader(http.StatusOK)
		})

		r, org := setupAuthentication(db)
		project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
			ID:             uuid.New(),
			OrganizationID: org.ID,
			Name:           "moo",
		})
		require.NoError(t, err)
		chi.RouteContext(r.Context()).URLParams.Add("project", project.ID.String())
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)
	})
}
|
54
coderd/httpmw/projectversionparam.go
Normal file
54
coderd/httpmw/projectversionparam.go
Normal file
@ -0,0 +1,54 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type projectVersionParamContextKey struct{}
|
||||
|
||||
// ProjectVersionParam returns the project version from the ExtractProjectVersionParam handler.
|
||||
func ProjectVersionParam(r *http.Request) database.ProjectVersion {
|
||||
projectVersion, ok := r.Context().Value(projectVersionParamContextKey{}).(database.ProjectVersion)
|
||||
if !ok {
|
||||
panic("developer error: project version param middleware not provided")
|
||||
}
|
||||
return projectVersion
|
||||
}
|
||||
|
||||
// ExtractProjectVersionParam grabs project version from the "projectversion" URL parameter.
|
||||
func ExtractProjectVersionParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
projectVersionID, parsed := parseUUID(rw, r, "projectversion")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
projectVersion, err := db.GetProjectVersionByID(r.Context(), projectVersionID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("project version %q does not exist", projectVersionID),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get project version: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), projectVersionParamContextKey{}, projectVersion)
|
||||
chi.RouteContext(ctx).URLParams.Add("organization", projectVersion.OrganizationID)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
150
coderd/httpmw/projectversionparam_test.go
Normal file
150
coderd/httpmw/projectversionparam_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// TestProjectVersionParam exercises ExtractProjectVersionParam: missing
// parameter, unknown version, and the success path through the middleware
// chain.
func TestProjectVersionParam(t *testing.T) {
	t.Parallel()

	// setupAuthentication seeds a user, API key cookie, organization (with
	// membership), and a project that versions can be attached to.
	setupAuthentication := func(db database.Store) (*http.Request, database.Project) {
		var (
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
		)
		r := httptest.NewRequest("GET", "/", nil)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})
		userID, err := cryptorand.String(16)
		require.NoError(t, err)
		username, err := cryptorand.String(8)
		require.NoError(t, err)
		user, err := db.InsertUser(r.Context(), database.InsertUserParams{
			ID:             userID,
			Email:          "testaccount@coder.com",
			Name:           "example",
			LoginType:      database.LoginTypeBuiltIn,
			HashedPassword: hashed[:],
			Username:       username,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			UserID:       user.ID,
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().Add(time.Minute),
		})
		require.NoError(t, err)
		orgID, err := cryptorand.String(16)
		require.NoError(t, err)
		organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
			ID:          orgID,
			Name:        "banana",
			Description: "wowie",
			CreatedAt:   database.Now(),
			UpdatedAt:   database.Now(),
		})
		require.NoError(t, err)
		_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
			OrganizationID: orgID,
			UserID:         user.ID,
			CreatedAt:      database.Now(),
			UpdatedAt:      database.Now(),
		})
		require.NoError(t, err)
		project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
			ID:             uuid.New(),
			OrganizationID: organization.ID,
			Name:           "moo",
		})
		require.NoError(t, err)

		ctx := chi.NewRouteContext()
		ctx.URLParams.Add("organization", organization.Name)
		ctx.URLParams.Add("project", project.Name)
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
		return r, project
	}

	t.Run("None", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		rtr.Use(httpmw.ExtractProjectVersionParam(db))
		rtr.Get("/", nil)
		r, _ := setupAuthentication(db)
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		// No "projectversion" URL parameter fails UUID parsing up front.
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("NotFound", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		rtr.Use(httpmw.ExtractProjectVersionParam(db))
		rtr.Get("/", nil)

		r, _ := setupAuthentication(db)
		// A valid UUID that matches no project version row.
		chi.RouteContext(r.Context()).URLParams.Add("projectversion", uuid.NewString())
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusNotFound, res.StatusCode)
	})

	t.Run("ProjectVersion", func(t *testing.T) {
		t.Parallel()
		db := databasefake.New()
		rtr := chi.NewRouter()
		// ExtractProjectVersionParam injects the "organization" parameter
		// consumed by ExtractOrganizationParam below it.
		rtr.Use(
			httpmw.ExtractAPIKey(db, nil),
			httpmw.ExtractProjectVersionParam(db),
			httpmw.ExtractOrganizationParam(db),
		)
		rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
			_ = httpmw.ProjectVersionParam(r)
			rw.WriteHeader(http.StatusOK)
		})

		r, project := setupAuthentication(db)
		projectVersion, err := db.InsertProjectVersion(context.Background(), database.InsertProjectVersionParams{
			ID:             uuid.New(),
			OrganizationID: project.OrganizationID,
			Name:           "moo",
		})
		require.NoError(t, err)
		chi.RouteContext(r.Context()).URLParams.Add("projectversion", projectVersion.ID.String())
		rw := httptest.NewRecorder()
		rtr.ServeHTTP(rw, r)

		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)
	})
}
|
54
coderd/httpmw/userparam.go
Normal file
54
coderd/httpmw/userparam.go
Normal file
@ -0,0 +1,54 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type userParamContextKey struct{}
|
||||
|
||||
// UserParam returns the user from the ExtractUserParam handler.
|
||||
func UserParam(r *http.Request) database.User {
|
||||
user, ok := r.Context().Value(userParamContextKey{}).(database.User)
|
||||
if !ok {
|
||||
panic("developer error: user parameter middleware not provided")
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// ExtractUserParam extracts a user from an ID/username in the {user} URL parameter.
|
||||
func ExtractUserParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "user")
|
||||
if userID == "" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "user id or name must be provided",
|
||||
})
|
||||
return
|
||||
}
|
||||
apiKey := APIKey(r)
|
||||
if apiKey.UserID != userID && userID != "me" {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "getting non-personal users isn't supported yet",
|
||||
})
|
||||
return
|
||||
}
|
||||
user, err := db.GetUserByID(r.Context(), apiKey.UserID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get user: %s", err.Error()),
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userParamContextKey{}, user)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
104
coderd/httpmw/userparam_test.go
Normal file
104
coderd/httpmw/userparam_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
// TestUserParam exercises ExtractUserParam: missing {user} parameter, a
// parameter naming someone other than the caller, and the "me" alias.
func TestUserParam(t *testing.T) {
	t.Parallel()
	// setup seeds a user ("bananas") plus a valid API key cookie so the
	// ExtractAPIKey middleware authenticates the request.
	setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) {
		var (
			db         = databasefake.New()
			id, secret = randomAPIKeyParts()
			hashed     = sha256.Sum256([]byte(secret))
			r          = httptest.NewRequest("GET", "/", nil)
			rw         = httptest.NewRecorder()
		)
		r.AddCookie(&http.Cookie{
			Name:  httpmw.AuthCookie,
			Value: fmt.Sprintf("%s-%s", id, secret),
		})

		_, err := db.InsertUser(r.Context(), database.InsertUserParams{
			ID: "bananas",
		})
		require.NoError(t, err)
		_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
			ID:           id,
			UserID:       "bananas",
			HashedSecret: hashed[:],
			LastUsed:     database.Now(),
			ExpiresAt:    database.Now().Add(time.Minute),
		})
		require.NoError(t, err)
		return db, rw, r
	}

	t.Run("None", func(t *testing.T) {
		t.Parallel()
		db, rw, r := setup(t)

		// Run ExtractAPIKey first and capture the authenticated request.
		httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
			r = returnedRequest
		})).ServeHTTP(rw, r)

		httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			rw.WriteHeader(http.StatusOK)
		})).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		// No {user} URL parameter was provided.
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("NotMe", func(t *testing.T) {
		t.Parallel()
		db, rw, r := setup(t)

		httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
			r = returnedRequest
		})).ServeHTTP(rw, r)

		// "ben" is neither the caller's ID nor the "me" alias.
		routeContext := chi.NewRouteContext()
		routeContext.URLParams.Add("user", "ben")
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
		httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			rw.WriteHeader(http.StatusOK)
		})).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusBadRequest, res.StatusCode)
	})

	t.Run("Me", func(t *testing.T) {
		t.Parallel()
		db, rw, r := setup(t)

		httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
			r = returnedRequest
		})).ServeHTTP(rw, r)

		routeContext := chi.NewRouteContext()
		routeContext.URLParams.Add("user", "me")
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
		httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
			// The accessor must succeed without panicking.
			_ = httpmw.UserParam(r)
			rw.WriteHeader(http.StatusOK)
		})).ServeHTTP(rw, r)
		res := rw.Result()
		defer res.Body.Close()
		require.Equal(t, http.StatusOK, res.StatusCode)
	})
}
|
65
coderd/httpmw/workspaceagent.go
Normal file
65
coderd/httpmw/workspaceagent.go
Normal file
@ -0,0 +1,65 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type workspaceAgentContextKey struct{}
|
||||
|
||||
// WorkspaceAgent returns the workspace agent from the ExtractAgent handler.
|
||||
func WorkspaceAgent(r *http.Request) database.WorkspaceAgent {
|
||||
user, ok := r.Context().Value(workspaceAgentContextKey{}).(database.WorkspaceAgent)
|
||||
if !ok {
|
||||
panic("developer error: agent middleware not provided")
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// ExtractWorkspaceAgent requires authentication using a valid agent token.
|
||||
func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(AuthCookie)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q cookie must be provided", AuthCookie),
|
||||
})
|
||||
return
|
||||
}
|
||||
token, err := uuid.Parse(cookie.Value)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("parse token: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
agent, err := db.GetWorkspaceAgentByAuthToken(r.Context(), token)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "agent token is invalid",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), workspaceAgentContextKey{}, agent)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
73
coderd/httpmw/workspaceagent_test.go
Normal file
73
coderd/httpmw/workspaceagent_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(db database.Store) (*http.Request, uuid.UUID) {
|
||||
token := uuid.New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: token.String(),
|
||||
})
|
||||
return r, token
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceAgent(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceAgent(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceAgent(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r, token := setup(db)
|
||||
_, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
|
||||
ID: uuid.New(),
|
||||
AuthToken: token,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
56
coderd/httpmw/workspacebuildparam.go
Normal file
56
coderd/httpmw/workspacebuildparam.go
Normal file
@ -0,0 +1,56 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type workspaceBuildParamContextKey struct{}
|
||||
|
||||
// WorkspaceBuildParam returns the workspace build from the ExtractWorkspaceBuildParam handler.
|
||||
func WorkspaceBuildParam(r *http.Request) database.WorkspaceBuild {
|
||||
workspaceBuild, ok := r.Context().Value(workspaceBuildParamContextKey{}).(database.WorkspaceBuild)
|
||||
if !ok {
|
||||
panic("developer error: workspace build param middleware not provided")
|
||||
}
|
||||
return workspaceBuild
|
||||
}
|
||||
|
||||
// ExtractWorkspaceBuildParam grabs workspace build from the "workspacebuild" URL parameter.
|
||||
func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
workspaceBuildID, parsed := parseUUID(rw, r, "workspacebuild")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
workspaceBuild, err := db.GetWorkspaceBuildByID(r.Context(), workspaceBuildID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("workspace build %q does not exist", workspaceBuildID),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace build: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), workspaceBuildParamContextKey{}, workspaceBuild)
|
||||
// This injects the "workspace" parameter, because it's expected the consumer
|
||||
// will want to use the Workspace middleware to ensure the caller owns the workspace.
|
||||
chi.RouteContext(ctx).URLParams.Add("workspace", workspaceBuild.WorkspaceID.String())
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
134
coderd/httpmw/workspacebuildparam_test.go
Normal file
134
coderd/httpmw/workspacebuildparam_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestWorkspaceBuildParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupAuthentication := func(db database.Store) (*http.Request, database.Workspace) {
|
||||
var (
|
||||
id, secret = randomAPIKeyParts()
|
||||
hashed = sha256.Sum256([]byte(secret))
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: fmt.Sprintf("%s-%s", id, secret),
|
||||
})
|
||||
userID, err := cryptorand.String(16)
|
||||
require.NoError(t, err)
|
||||
username, err := cryptorand.String(8)
|
||||
require.NoError(t, err)
|
||||
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
ProjectID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
Name: "potato",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("user", userID)
|
||||
ctx.URLParams.Add("workspace", workspace.Name)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, workspace
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setupAuthentication(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspacebuild", uuid.NewString())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceBuild", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractWorkspaceBuildParam(db),
|
||||
httpmw.ExtractWorkspaceParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceBuildParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, workspace := setupAuthentication(db)
|
||||
workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: workspace.ID,
|
||||
Name: "moo",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspacebuild", workspaceBuild.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
59
coderd/httpmw/workspaceparam.go
Normal file
59
coderd/httpmw/workspaceparam.go
Normal file
@ -0,0 +1,59 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type workspaceParamContextKey struct{}
|
||||
|
||||
// WorkspaceParam returns the workspace from the ExtractWorkspaceParam handler.
|
||||
func WorkspaceParam(r *http.Request) database.Workspace {
|
||||
workspace, ok := r.Context().Value(workspaceParamContextKey{}).(database.Workspace)
|
||||
if !ok {
|
||||
panic("developer error: workspace param middleware not provided")
|
||||
}
|
||||
return workspace
|
||||
}
|
||||
|
||||
// ExtractWorkspaceParam grabs a workspace from the "workspace" URL parameter.
|
||||
func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
workspaceID, parsed := parseUUID(rw, r, "workspace")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
workspace, err := db.GetWorkspaceByID(r.Context(), workspaceID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: fmt.Sprintf("workspace %q does not exist", workspaceID),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := APIKey(r)
|
||||
if apiKey.UserID != workspace.OwnerID {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "getting non-personal workspaces isn't supported",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), workspaceParamContextKey{}, workspace)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
148
coderd/httpmw/workspaceparam_test.go
Normal file
148
coderd/httpmw/workspaceparam_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestWorkspaceParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(db database.Store) (*http.Request, database.User) {
|
||||
var (
|
||||
id, secret = randomAPIKeyParts()
|
||||
hashed = sha256.Sum256([]byte(secret))
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: fmt.Sprintf("%s-%s", id, secret),
|
||||
})
|
||||
userID, err := cryptorand.String(16)
|
||||
require.NoError(t, err)
|
||||
username, err := cryptorand.String(8)
|
||||
require.NoError(t, err)
|
||||
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("user", "me")
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, user
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceParam(db))
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceParam(db))
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspace", uuid.NewString())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NonPersonal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractWorkspaceParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
OwnerID: "not-me",
|
||||
Name: "hello",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspace", workspace.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractWorkspaceParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r, user := setup(db)
|
||||
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
Name: "hello",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspace", workspace.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
76
coderd/httpmw/workspaceresourceparam.go
Normal file
76
coderd/httpmw/workspaceresourceparam.go
Normal file
@ -0,0 +1,76 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type workspaceResourceParamContextKey struct{}
|
||||
|
||||
// ProvisionerJobParam returns the project from the ExtractProjectParam handler.
|
||||
func WorkspaceResourceParam(r *http.Request) database.WorkspaceResource {
|
||||
resource, ok := r.Context().Value(workspaceResourceParamContextKey{}).(database.WorkspaceResource)
|
||||
if !ok {
|
||||
panic("developer error: workspace resource param middleware not provided")
|
||||
}
|
||||
return resource
|
||||
}
|
||||
|
||||
// ExtractWorkspaceResourceParam grabs a workspace resource from the "provisionerjob" URL parameter.
|
||||
func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
resourceUUID, parsed := parseUUID(rw, r, "workspaceresource")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
resource, err := db.GetWorkspaceResourceByID(r.Context(), resourceUUID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: "resource doesn't exist with that id",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner resource: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner job: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
if job.Type != database.ProvisionerJobTypeWorkspaceBuild {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "Workspace resources can only be fetched for builds.",
|
||||
})
|
||||
return
|
||||
}
|
||||
build, err := db.GetWorkspaceBuildByJobID(r.Context(), job.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace build: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), workspaceResourceParamContextKey{}, resource)
|
||||
ctx = context.WithValue(ctx, workspaceBuildParamContextKey{}, build)
|
||||
chi.RouteContext(ctx).URLParams.Add("workspace", build.WorkspaceID.String())
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
122
coderd/httpmw/workspaceresourceparam_test.go
Normal file
122
coderd/httpmw/workspaceresourceparam_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
func TestWorkspaceResourceParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
|
||||
ID: uuid.New(),
|
||||
Type: jobType,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
|
||||
ID: uuid.New(),
|
||||
JobID: job.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{
|
||||
ID: uuid.New(),
|
||||
JobID: job.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("workspacebuild", workspaceBuild.ID.String())
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, resource
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceResourceParam(db))
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceResourceParam(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", uuid.NewString())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("FoundBadJobType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceResourceParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceResourceParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, job := setup(db, database.ProvisionerJobTypeProjectVersionImport)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceResourceParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceResourceParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, job := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
@ -12,10 +12,10 @@ import (
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func (*api) organization(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
)
|
||||
|
||||
// ComputeScope targets identifiers to pull parameters from.
|
||||
|
@ -7,10 +7,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/database/databasefake"
|
||||
)
|
||||
|
||||
func TestCompute(t *testing.T) {
|
||||
|
@ -10,9 +10,9 @@ import (
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
)
|
||||
|
||||
func (api *api) postParameter(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -8,8 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
)
|
||||
|
||||
func TestPostParameter(t *testing.T) {
|
||||
|
@ -10,10 +10,10 @@ import (
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
// Returns a single project.
|
||||
|
@ -8,11 +8,11 @@ import (
|
||||
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func (api *api) projectVersion(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -23,9 +23,9 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
sdkproto "github.com/coder/coder/provisionersdk/proto"
|
||||
|
@ -15,9 +15,9 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
)
|
||||
|
||||
// Returns provisioner logs based on query parameters.
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
@ -15,12 +15,12 @@ import (
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/userpassword"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
// Returns whether the initial user has been created or not.
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func TestFirstUser(t *testing.T) {
|
||||
|
@ -7,10 +7,10 @@ import (
|
||||
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func (api *api) workspaceBuild(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -9,9 +9,9 @@ import (
|
||||
|
||||
"github.com/go-chi/render"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
@ -15,10 +15,10 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
|
@ -13,10 +13,10 @@ import (
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/httpapi"
|
||||
"github.com/coder/coder/httpmw"
|
||||
)
|
||||
|
||||
func (api *api) workspace(rw http.ResponseWriter, r *http.Request) {
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
Reference in New Issue
Block a user