chore: Move httpapi, httpmw, & database into coderd (#568)

* chore: Move httpmw to /coderd directory
httpmw is specific to coderd and should be scoped under coderd

* chore: Move httpapi to /coderd directory
httpapi is specific to coderd and should be scoped under coderd

* chore: Move database  to /coderd directory
database is specific to coderd and should be scoped under coderd

* chore: Update codecov & gitattributes for generated files
* chore: Update Makefile
This commit is contained in:
Steven Masley
2022-03-25 16:07:45 -05:00
committed by GitHub
parent 6be949a88e
commit 591523a078
98 changed files with 155 additions and 155 deletions

View File

@ -10,9 +10,9 @@ import (
"google.golang.org/api/idtoken"
"cdr.dev/slog"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/site"
)

View File

@ -31,11 +31,11 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
"github.com/coder/coder/database/postgres"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionersdk"

File diff suppressed because it is too large Load Diff

75
coderd/database/db.go Normal file
View File

@ -0,0 +1,75 @@
// Package database connects to external services for stateful storage.
//
// Query functions are generated using sqlc.
//
// To modify the database schema:
// 1. Add a new migration using "create_migration.sh" in database/migrations/
// 2. Run "make coderd/database/generate" in the root to generate models.
// 3. Add/Edit queries in "query.sql" and run "make coderd/database/generate" to create Go code.
package database
import (
"context"
"database/sql"
"errors"
"golang.org/x/xerrors"
)
// Store contains all queryable database functions.
// It extends the generated interface to add transaction support.
type Store interface {
querier
InTx(func(Store) error) error
}
// DBTX represents a database connection or transaction.
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
// New creates a new database store using a SQL database connection.
func New(sdb *sql.DB) Store {
return &sqlQuerier{
db: sdb,
sdb: sdb,
}
}
type sqlQuerier struct {
sdb *sql.DB
db DBTX
}
// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(function func(Store) error) error {
if q.sdb == nil {
return nil
}
transaction, err := q.sdb.Begin()
if err != nil {
return xerrors.Errorf("begin transaction: %w", err)
}
defer func() {
rerr := transaction.Rollback()
if rerr == nil || errors.Is(rerr, sql.ErrTxDone) {
// no need to do anything, tx committed successfully
return
}
// couldn't roll back for some reason, extend returned error
err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err)
}()
err = function(&sqlQuerier{db: transaction})
if err != nil {
return xerrors.Errorf("execute transaction: %w", err)
}
err = transaction.Commit()
if err != nil {
return xerrors.Errorf("commit transaction: %w", err)
}
return nil
}

378
coderd/database/dump.sql generated Normal file
View File

@ -0,0 +1,378 @@
-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.
CREATE TYPE log_level AS ENUM (
'trace',
'debug',
'info',
'warn',
'error'
);
CREATE TYPE log_source AS ENUM (
'provisioner_daemon',
'provisioner'
);
CREATE TYPE login_type AS ENUM (
'built-in',
'saml',
'oidc'
);
CREATE TYPE parameter_destination_scheme AS ENUM (
'none',
'environment_variable',
'provisioner_variable'
);
CREATE TYPE parameter_scope AS ENUM (
'organization',
'project',
'import_job',
'user',
'workspace'
);
CREATE TYPE parameter_source_scheme AS ENUM (
'none',
'data'
);
CREATE TYPE parameter_type_system AS ENUM (
'none',
'hcl'
);
CREATE TYPE provisioner_job_type AS ENUM (
'project_version_import',
'workspace_build'
);
CREATE TYPE provisioner_storage_method AS ENUM (
'file'
);
CREATE TYPE provisioner_type AS ENUM (
'echo',
'terraform'
);
CREATE TYPE userstatus AS ENUM (
'active',
'dormant',
'decommissioned'
);
CREATE TYPE workspace_transition AS ENUM (
'start',
'stop',
'delete'
);
CREATE TABLE api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
user_id text NOT NULL,
application boolean NOT NULL,
name text NOT NULL,
last_used timestamp with time zone NOT NULL,
expires_at timestamp with time zone NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
login_type login_type NOT NULL,
oidc_access_token text DEFAULT ''::text NOT NULL,
oidc_refresh_token text DEFAULT ''::text NOT NULL,
oidc_id_token text DEFAULT ''::text NOT NULL,
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
devurl_token boolean DEFAULT false NOT NULL
);
CREATE TABLE files (
hash character varying(64) NOT NULL,
created_at timestamp with time zone NOT NULL,
created_by text NOT NULL,
mimetype character varying(64) NOT NULL,
data bytea NOT NULL
);
CREATE TABLE licenses (
id integer NOT NULL,
license jsonb NOT NULL,
created_at timestamp with time zone NOT NULL
);
CREATE TABLE organization_members (
organization_id text NOT NULL,
user_id text NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
roles text[] DEFAULT '{organization-member}'::text[] NOT NULL
);
CREATE TABLE organizations (
id text NOT NULL,
name text NOT NULL,
description text NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
"default" boolean DEFAULT false NOT NULL,
auto_off_threshold bigint DEFAULT '28800000000000'::bigint NOT NULL,
cpu_provisioning_rate real DEFAULT 4.0 NOT NULL,
memory_provisioning_rate real DEFAULT 1.0 NOT NULL,
workspace_auto_off boolean DEFAULT false NOT NULL
);
CREATE TABLE parameter_schemas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
job_id uuid NOT NULL,
name character varying(64) NOT NULL,
description character varying(8192) DEFAULT ''::character varying NOT NULL,
default_source_scheme parameter_source_scheme,
default_source_value text NOT NULL,
allow_override_source boolean NOT NULL,
default_destination_scheme parameter_destination_scheme,
allow_override_destination boolean NOT NULL,
default_refresh text NOT NULL,
redisplay_value boolean NOT NULL,
validation_error character varying(256) NOT NULL,
validation_condition character varying(512) NOT NULL,
validation_type_system parameter_type_system NOT NULL,
validation_value_type character varying(64) NOT NULL
);
CREATE TABLE parameter_values (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
scope parameter_scope NOT NULL,
scope_id text NOT NULL,
name character varying(64) NOT NULL,
source_scheme parameter_source_scheme NOT NULL,
source_value text NOT NULL,
destination_scheme parameter_destination_scheme NOT NULL
);
CREATE TABLE project_versions (
id uuid NOT NULL,
project_id uuid,
organization_id text NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
name character varying(64) NOT NULL,
description character varying(1048576) NOT NULL,
job_id uuid NOT NULL
);
CREATE TABLE projects (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
organization_id text NOT NULL,
deleted boolean DEFAULT false NOT NULL,
name character varying(64) NOT NULL,
provisioner provisioner_type NOT NULL,
active_version_id uuid NOT NULL
);
CREATE TABLE provisioner_daemons (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone,
organization_id text,
name character varying(64) NOT NULL,
provisioners provisioner_type[] NOT NULL
);
CREATE TABLE provisioner_job_logs (
id uuid NOT NULL,
job_id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
source log_source NOT NULL,
level log_level NOT NULL,
output character varying(1024) NOT NULL
);
CREATE TABLE provisioner_jobs (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
started_at timestamp with time zone,
canceled_at timestamp with time zone,
completed_at timestamp with time zone,
error text,
organization_id text NOT NULL,
initiator_id text NOT NULL,
provisioner provisioner_type NOT NULL,
storage_method provisioner_storage_method NOT NULL,
storage_source text NOT NULL,
type provisioner_job_type NOT NULL,
input jsonb NOT NULL,
worker_id uuid
);
CREATE TABLE users (
id text NOT NULL,
email text NOT NULL,
name text NOT NULL,
revoked boolean NOT NULL,
login_type login_type NOT NULL,
hashed_password bytea NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
temporary_password boolean DEFAULT false NOT NULL,
avatar_hash text DEFAULT ''::text NOT NULL,
ssh_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
username text DEFAULT ''::text NOT NULL,
dotfiles_git_uri text DEFAULT ''::text NOT NULL,
roles text[] DEFAULT '{site-member}'::text[] NOT NULL,
status userstatus DEFAULT 'active'::public.userstatus NOT NULL,
relatime timestamp with time zone DEFAULT now() NOT NULL,
gpg_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
_decomissioned boolean DEFAULT false NOT NULL,
shell text DEFAULT ''::text NOT NULL
);
CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
first_connected_at timestamp with time zone,
last_connected_at timestamp with time zone,
disconnected_at timestamp with time zone,
resource_id uuid NOT NULL,
auth_token uuid NOT NULL,
auth_instance_id character varying(64),
environment_variables jsonb,
startup_script character varying(65534),
instance_metadata jsonb,
resource_metadata jsonb
);
CREATE TABLE workspace_builds (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
workspace_id uuid NOT NULL,
project_version_id uuid NOT NULL,
name character varying(64) NOT NULL,
before_id uuid,
after_id uuid,
transition workspace_transition NOT NULL,
initiator character varying(255) NOT NULL,
provisioner_state bytea,
job_id uuid NOT NULL
);
CREATE TABLE workspace_resources (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
job_id uuid NOT NULL,
transition workspace_transition NOT NULL,
address character varying(256) NOT NULL,
type character varying(192) NOT NULL,
name character varying(64) NOT NULL,
agent_id uuid
);
CREATE TABLE workspaces (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
owner_id text NOT NULL,
project_id uuid NOT NULL,
deleted boolean DEFAULT false NOT NULL,
name character varying(64) NOT NULL
);
ALTER TABLE ONLY files
ADD CONSTRAINT files_hash_key UNIQUE (hash);
ALTER TABLE ONLY parameter_schemas
ADD CONSTRAINT parameter_schemas_id_key UNIQUE (id);
ALTER TABLE ONLY parameter_schemas
ADD CONSTRAINT parameter_schemas_job_id_name_key UNIQUE (job_id, name);
ALTER TABLE ONLY parameter_values
ADD CONSTRAINT parameter_values_id_key UNIQUE (id);
ALTER TABLE ONLY parameter_values
ADD CONSTRAINT parameter_values_scope_id_name_key UNIQUE (scope_id, name);
ALTER TABLE ONLY project_versions
ADD CONSTRAINT project_versions_id_key UNIQUE (id);
ALTER TABLE ONLY project_versions
ADD CONSTRAINT project_versions_project_id_name_key UNIQUE (project_id, name);
ALTER TABLE ONLY projects
ADD CONSTRAINT projects_id_key UNIQUE (id);
ALTER TABLE ONLY projects
ADD CONSTRAINT projects_organization_id_name_key UNIQUE (organization_id, name);
ALTER TABLE ONLY provisioner_daemons
ADD CONSTRAINT provisioner_daemons_id_key UNIQUE (id);
ALTER TABLE ONLY provisioner_daemons
ADD CONSTRAINT provisioner_daemons_name_key UNIQUE (name);
ALTER TABLE ONLY provisioner_job_logs
ADD CONSTRAINT provisioner_job_logs_id_key UNIQUE (id);
ALTER TABLE ONLY provisioner_jobs
ADD CONSTRAINT provisioner_jobs_id_key UNIQUE (id);
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_auth_token_key UNIQUE (auth_token);
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_id_key UNIQUE (id);
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_id_key UNIQUE (id);
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_job_id_key UNIQUE (job_id);
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_workspace_id_name_key UNIQUE (workspace_id, name);
ALTER TABLE ONLY workspace_resources
ADD CONSTRAINT workspace_resources_id_key UNIQUE (id);
ALTER TABLE ONLY workspaces
ADD CONSTRAINT workspaces_id_key UNIQUE (id);
CREATE UNIQUE INDEX projects_organization_id_name_idx ON projects USING btree (organization_id, name) WHERE (deleted = false);
CREATE UNIQUE INDEX workspaces_owner_id_name_idx ON workspaces USING btree (owner_id, name) WHERE (deleted = false);
ALTER TABLE ONLY parameter_schemas
ADD CONSTRAINT parameter_schemas_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ALTER TABLE ONLY project_versions
ADD CONSTRAINT project_versions_project_id_fkey FOREIGN KEY (project_id) REFERENCES projects(id);
ALTER TABLE ONLY provisioner_job_logs
ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_project_version_id_fkey FOREIGN KEY (project_version_id) REFERENCES project_versions(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_resources
ADD CONSTRAINT workspace_resources_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspaces
ADD CONSTRAINT workspaces_project_id_fkey FOREIGN KEY (project_id) REFERENCES projects(id);

View File

@ -0,0 +1,88 @@
package main
import (
"bytes"
"database/sql"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
)
func main() {
connection, closeFn, err := postgres.Open()
if err != nil {
panic(err)
}
defer closeFn()
db, err := sql.Open("postgres", connection)
if err != nil {
panic(err)
}
err = database.MigrateUp(db)
if err != nil {
panic(err)
}
cmd := exec.Command(
"pg_dump",
"--schema-only",
connection,
"--no-privileges",
"--no-owner",
"--no-comments",
// We never want to manually generate
// queries executing against this table.
"--exclude-table=schema_migrations",
)
cmd.Env = []string{
"PGTZ=UTC",
"PGCLIENTENCODING=UTF8",
}
var output bytes.Buffer
cmd.Stdout = &output
cmd.Stderr = os.Stderr
err = cmd.Run()
if err != nil {
panic(err)
}
for _, sed := range []string{
// Remove all comments.
"/^--/d",
// Public is implicit in the schema.
"s/ public\\./ /",
// Remove database settings.
"s/SET.*;//g",
// Remove select statements. These aren't useful
// to a reader of the dump.
"s/SELECT.*;//g",
// Removes multiple newlines.
"/^$/N;/^\\n$/D",
} {
cmd := exec.Command("sed", "-e", sed)
cmd.Stdin = bytes.NewReader(output.Bytes())
output = bytes.Buffer{}
cmd.Stdout = &output
cmd.Stderr = os.Stderr
err = cmd.Run()
if err != nil {
panic(err)
}
}
dump := fmt.Sprintf("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.\n%s", output.Bytes())
_, mainPath, _, ok := runtime.Caller(0)
if !ok {
panic("couldn't get caller path")
}
err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0600)
if err != nil {
panic(err)
}
}

View File

@ -0,0 +1,74 @@
package database
import (
"database/sql"
"embed"
"errors"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/iofs"
"golang.org/x/xerrors"
)
//go:embed migrations/*.sql
var migrations embed.FS
func migrateSetup(db *sql.DB) (*migrate.Migrate, error) {
sourceDriver, err := iofs.New(migrations, "migrations")
if err != nil {
return nil, xerrors.Errorf("create iofs: %w", err)
}
dbDriver, err := postgres.WithInstance(db, &postgres.Config{})
if err != nil {
return nil, xerrors.Errorf("wrap postgres connection: %w", err)
}
m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
if err != nil {
return nil, xerrors.Errorf("new migrate instance: %w", err)
}
return m, nil
}
// MigrateUp runs SQL migrations to ensure the database schema is up-to-date.
func MigrateUp(db *sql.DB) error {
m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
err = m.Up()
if err != nil {
if errors.Is(err, migrate.ErrNoChange) {
// It's OK if no changes happened!
return nil
}
return xerrors.Errorf("up: %w", err)
}
return nil
}
// MigrateDown runs all down SQL migrations.
func MigrateDown(db *sql.DB) error {
m, err := migrateSetup(db)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
err = m.Down()
if err != nil {
if errors.Is(err, migrate.ErrNoChange) {
// It's OK if no changes happened!
return nil
}
return xerrors.Errorf("down: %w", err)
}
return nil
}

View File

@ -0,0 +1,77 @@
//go:build linux
package database_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestMigrate(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip()
return
}
t.Run("Once", func(t *testing.T) {
t.Parallel()
db := testSQLDB(t)
err := database.MigrateUp(db)
require.NoError(t, err)
})
t.Run("Twice", func(t *testing.T) {
t.Parallel()
db := testSQLDB(t)
err := database.MigrateUp(db)
require.NoError(t, err)
err = database.MigrateUp(db)
require.NoError(t, err)
})
t.Run("UpDownUp", func(t *testing.T) {
t.Parallel()
db := testSQLDB(t)
err := database.MigrateUp(db)
require.NoError(t, err)
err = database.MigrateDown(db)
require.NoError(t, err)
err = database.MigrateUp(db)
require.NoError(t, err)
})
}
func testSQLDB(t testing.TB) *sql.DB {
t.Helper()
connection, closeFn, err := postgres.Open()
require.NoError(t, err)
t.Cleanup(closeFn)
db, err := sql.Open("postgres", connection)
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
return db
}

View File

@ -0,0 +1,92 @@
-- This migration creates tables and types for v1 if they do not exist.
-- This allows v2 to operate independently of v1, but share data if it exists.
--
-- All tables and types are stolen from:
-- https://github.com/coder/m/blob/47b6fc383347b9f9fab424d829c482defd3e1fe2/product/coder/pkg/database/dump.sql
DO $$ BEGIN
CREATE TYPE login_type AS ENUM (
'built-in',
'saml',
'oidc'
);
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
DO $$ BEGIN
CREATE TYPE userstatus AS ENUM (
'active',
'dormant',
'decommissioned'
);
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
CREATE TABLE IF NOT EXISTS users (
id text NOT NULL,
email text NOT NULL,
name text NOT NULL,
revoked boolean NOT NULL,
login_type login_type NOT NULL,
hashed_password bytea NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
temporary_password boolean DEFAULT false NOT NULL,
avatar_hash text DEFAULT '' :: text NOT NULL,
ssh_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
username text DEFAULT '' :: text NOT NULL,
dotfiles_git_uri text DEFAULT '' :: text NOT NULL,
roles text [] DEFAULT '{site-member}' :: text [] NOT NULL,
status userstatus DEFAULT 'active' :: public.userstatus NOT NULL,
relatime timestamp with time zone DEFAULT now() NOT NULL,
gpg_key_regenerated_at timestamp with time zone DEFAULT now() NOT NULL,
_decomissioned boolean DEFAULT false NOT NULL,
shell text DEFAULT '' :: text NOT NULL
);
CREATE TABLE IF NOT EXISTS organizations (
id text NOT NULL,
name text NOT NULL,
description text NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
"default" boolean DEFAULT false NOT NULL,
auto_off_threshold bigint DEFAULT '28800000000000' :: bigint NOT NULL,
cpu_provisioning_rate real DEFAULT 4.0 NOT NULL,
memory_provisioning_rate real DEFAULT 1.0 NOT NULL,
workspace_auto_off boolean DEFAULT false NOT NULL
);
CREATE TABLE IF NOT EXISTS organization_members (
organization_id text NOT NULL,
user_id text NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
roles text [] DEFAULT '{organization-member}' :: text [] NOT NULL
);
CREATE TABLE IF NOT EXISTS api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
user_id text NOT NULL,
application boolean NOT NULL,
name text NOT NULL,
last_used timestamp with time zone NOT NULL,
expires_at timestamp with time zone NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
login_type login_type NOT NULL,
oidc_access_token text DEFAULT ''::text NOT NULL,
oidc_refresh_token text DEFAULT ''::text NOT NULL,
oidc_id_token text DEFAULT ''::text NOT NULL,
oidc_expiry timestamp with time zone DEFAULT '0001-01-01 00:00:00+00'::timestamp with time zone NOT NULL,
devurl_token boolean DEFAULT false NOT NULL
);
CREATE TABLE IF NOT EXISTS licenses (
id integer NOT NULL,
license jsonb NOT NULL,
created_at timestamp with time zone NOT NULL
);

View File

@ -0,0 +1,6 @@
DROP TABLE project_versions;
DROP TABLE projects;
DROP TYPE provisioner_type;
DROP TABLE files;

View File

@ -0,0 +1,55 @@
-- Store arbitrary data like project source code or avatars.
CREATE TABLE files (
hash varchar(64) NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
created_by text NOT NULL,
mimetype varchar(64) NOT NULL,
data bytea NOT NULL
);
CREATE TYPE provisioner_type AS ENUM ('echo', 'terraform');
-- Project defines infrastructure that your software project
-- requires for development.
CREATE TABLE projects (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
-- Projects must be scoped to an organization.
organization_id text NOT NULL,
deleted boolean NOT NULL DEFAULT FALSE,
name varchar(64) NOT NULL,
provisioner provisioner_type NOT NULL,
-- Target's a Project Version to use for Workspaces.
-- If a Workspace doesn't match this version, it will be prompted to rebuild.
active_version_id uuid NOT NULL,
-- Disallow projects to have the same name under
-- the same organization.
UNIQUE(organization_id, name)
);
-- Enforces no active projects have the same name.
CREATE UNIQUE INDEX ON projects (organization_id, name) WHERE deleted = FALSE;
-- Project Versions store historical project data. When a Project Version is imported,
-- an "import" job is queued to parse parameters. A Project Version
-- can only be used if the import job succeeds.
CREATE TABLE project_versions (
id uuid NOT NULL UNIQUE,
-- This should be indexed.
project_id uuid REFERENCES projects (id),
organization_id text NOT NULL,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
-- Name is generated for ease of differentiation.
-- eg. TheCozyRabbit16
name varchar(64) NOT NULL,
-- Extracted from a README.md on import.
-- Maximum of 1MB.
description varchar(1048576) NOT NULL,
-- The job ID for building the project version.
job_id uuid NOT NULL,
-- Disallow projects to have the same build name
-- multiple times.
UNIQUE(project_id, name)
);

View File

@ -0,0 +1,2 @@
DROP TYPE workspace_transition;
DROP TABLE workspaces

View File

@ -0,0 +1,19 @@
CREATE TABLE workspaces (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
owner_id text NOT NULL,
project_id uuid NOT NULL REFERENCES projects (id),
deleted boolean NOT NULL DEFAULT FALSE,
name varchar(64) NOT NULL
);
-- Enforces no active workspaces have the same name.
CREATE UNIQUE INDEX ON workspaces (owner_id, name) WHERE deleted = FALSE;
CREATE TYPE workspace_transition AS ENUM (
'start',
'stop',
'delete'
);

View File

@ -0,0 +1,21 @@
DROP TABLE workspace_builds;
DROP TABLE parameter_values;
DROP TABLE parameter_schemas;
DROP TYPE parameter_destination_scheme;
DROP TYPE parameter_source_scheme;
DROP TYPE parameter_type_system;
DROP TYPE parameter_scope;
DROP TABLE workspace_agents;
DROP TABLE workspace_resources;
DROP TABLE provisioner_job_logs;
DROP TYPE log_source;
DROP TYPE log_level;
DROP TABLE provisioner_jobs;
DROP TYPE provisioner_storage_method;
DROP TYPE provisioner_job_type;
DROP TABLE provisioner_daemons;

View File

@ -0,0 +1,167 @@
CREATE TABLE IF NOT EXISTS provisioner_daemons (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz,
organization_id text,
-- Name is generated for ease of differentiation.
-- eg. WowBananas16
name varchar(64) NOT NULL UNIQUE,
provisioners provisioner_type [ ] NOT NULL
);
CREATE TYPE provisioner_job_type AS ENUM (
'project_version_import',
'workspace_build'
);
CREATE TYPE provisioner_storage_method AS ENUM ('file');
CREATE TABLE IF NOT EXISTS provisioner_jobs (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
started_at timestamptz,
canceled_at timestamptz,
completed_at timestamptz,
error text,
organization_id text NOT NULL,
initiator_id text NOT NULL,
provisioner provisioner_type NOT NULL,
storage_method provisioner_storage_method NOT NULL,
storage_source text NOT NULL,
type provisioner_job_type NOT NULL,
input jsonb NOT NULL,
worker_id uuid
);
CREATE TYPE log_level AS ENUM (
'trace',
'debug',
'info',
'warn',
'error'
);
CREATE TYPE log_source AS ENUM (
'provisioner_daemon',
'provisioner'
);
CREATE TABLE IF NOT EXISTS provisioner_job_logs (
id uuid NOT NULL UNIQUE,
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
created_at timestamptz NOT NULL,
source log_source NOT NULL,
level log_level NOT NULL,
output varchar(1024) NOT NULL
);
CREATE TABLE workspace_resources (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
transition workspace_transition NOT NULL,
address varchar(256) NOT NULL,
type varchar(192) NOT NULL,
name varchar(64) NOT NULL,
agent_id uuid
);
CREATE TABLE workspace_agents (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
first_connected_at timestamptz,
last_connected_at timestamptz,
disconnected_at timestamptz,
resource_id uuid NOT NULL REFERENCES workspace_resources (id) ON DELETE CASCADE,
auth_token uuid NOT NULL UNIQUE,
auth_instance_id varchar(64),
environment_variables jsonb,
startup_script varchar(65534),
instance_metadata jsonb,
resource_metadata jsonb
);
CREATE TYPE parameter_scope AS ENUM (
'organization',
'project',
'import_job',
'user',
'workspace'
);
-- Types of parameters the automator supports.
CREATE TYPE parameter_type_system AS ENUM ('none', 'hcl');
-- Supported schemes for a parameter source.
CREATE TYPE parameter_source_scheme AS ENUM('none', 'data');
-- Supported schemes for a parameter destination.
CREATE TYPE parameter_destination_scheme AS ENUM('none', 'environment_variable', 'provisioner_variable');
-- Stores project version parameters parsed on import.
-- No secrets are stored here.
--
-- All parameter validation occurs server-side to process
-- complex validations.
--
-- Parameter types, description, and validation will produce
-- a UI for users to enter values.
-- Needs to be made consistent with the examples below.
CREATE TABLE parameter_schemas (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
job_id uuid NOT NULL REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
name varchar(64) NOT NULL,
description varchar(8192) NOT NULL DEFAULT '',
default_source_scheme parameter_source_scheme,
default_source_value text NOT NULL,
-- Allows the user to override the source.
allow_override_source boolean NOT null,
default_destination_scheme parameter_destination_scheme,
-- Allows the user to override the destination.
allow_override_destination boolean NOT null,
default_refresh text NOT NULL,
-- Whether the consumer can view the source and destinations.
redisplay_value boolean NOT null,
-- This error would appear in the UI if the condition is not met.
validation_error varchar(256) NOT NULL,
validation_condition varchar(512) NOT NULL,
validation_type_system parameter_type_system NOT NULL,
validation_value_type varchar(64) NOT NULL,
UNIQUE(job_id, name)
);
-- Parameters are provided to jobs for provisioning and to workspaces.
CREATE TABLE parameter_values (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
scope parameter_scope NOT NULL,
scope_id text NOT NULL,
name varchar(64) NOT NULL,
source_scheme parameter_source_scheme NOT NULL,
source_value text NOT NULL,
destination_scheme parameter_destination_scheme NOT NULL,
-- Prevents duplicates for parameters in the same scope.
UNIQUE(scope_id, name)
);
CREATE TABLE workspace_builds (
id uuid NOT NULL UNIQUE,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
workspace_id uuid NOT NULL REFERENCES workspaces (id) ON DELETE CASCADE,
project_version_id uuid NOT NULL REFERENCES project_versions (id) ON DELETE CASCADE,
name varchar(64) NOT NULL,
before_id uuid,
after_id uuid,
transition workspace_transition NOT NULL,
initiator varchar(255) NOT NULL,
-- State stored by the provisioner
provisioner_state bytea,
-- Job ID of the action
job_id uuid NOT NULL UNIQUE REFERENCES provisioner_jobs (id) ON DELETE CASCADE,
UNIQUE(workspace_id, name)
);

View File

@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")"
if [ -z "$1" ]; then
echo "First argument is the migration name!"
exit 1
fi
migrate create -ext sql -dir . -seq "$1"
echo "Run \"make gen\" to generate models."

466
coderd/database/models.go Normal file
View File

@ -0,0 +1,466 @@
// Code generated by sqlc. DO NOT EDIT.
package database
import (
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
)
type LogLevel string
const (
LogLevelTrace LogLevel = "trace"
LogLevelDebug LogLevel = "debug"
LogLevelInfo LogLevel = "info"
LogLevelWarn LogLevel = "warn"
LogLevelError LogLevel = "error"
)
func (e *LogLevel) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = LogLevel(s)
case string:
*e = LogLevel(s)
default:
return fmt.Errorf("unsupported scan type for LogLevel: %T", src)
}
return nil
}
type LogSource string
const (
LogSourceProvisionerDaemon LogSource = "provisioner_daemon"
LogSourceProvisioner LogSource = "provisioner"
)
func (e *LogSource) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = LogSource(s)
case string:
*e = LogSource(s)
default:
return fmt.Errorf("unsupported scan type for LogSource: %T", src)
}
return nil
}
type LoginType string
const (
LoginTypeBuiltIn LoginType = "built-in"
LoginTypeSaml LoginType = "saml"
LoginTypeOIDC LoginType = "oidc"
)
func (e *LoginType) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = LoginType(s)
case string:
*e = LoginType(s)
default:
return fmt.Errorf("unsupported scan type for LoginType: %T", src)
}
return nil
}
type ParameterDestinationScheme string
const (
ParameterDestinationSchemeNone ParameterDestinationScheme = "none"
ParameterDestinationSchemeEnvironmentVariable ParameterDestinationScheme = "environment_variable"
ParameterDestinationSchemeProvisionerVariable ParameterDestinationScheme = "provisioner_variable"
)
func (e *ParameterDestinationScheme) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ParameterDestinationScheme(s)
case string:
*e = ParameterDestinationScheme(s)
default:
return fmt.Errorf("unsupported scan type for ParameterDestinationScheme: %T", src)
}
return nil
}
type ParameterScope string
const (
ParameterScopeOrganization ParameterScope = "organization"
ParameterScopeProject ParameterScope = "project"
ParameterScopeImportJob ParameterScope = "import_job"
ParameterScopeUser ParameterScope = "user"
ParameterScopeWorkspace ParameterScope = "workspace"
)
func (e *ParameterScope) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ParameterScope(s)
case string:
*e = ParameterScope(s)
default:
return fmt.Errorf("unsupported scan type for ParameterScope: %T", src)
}
return nil
}
type ParameterSourceScheme string
const (
ParameterSourceSchemeNone ParameterSourceScheme = "none"
ParameterSourceSchemeData ParameterSourceScheme = "data"
)
func (e *ParameterSourceScheme) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ParameterSourceScheme(s)
case string:
*e = ParameterSourceScheme(s)
default:
return fmt.Errorf("unsupported scan type for ParameterSourceScheme: %T", src)
}
return nil
}
type ParameterTypeSystem string
const (
ParameterTypeSystemNone ParameterTypeSystem = "none"
ParameterTypeSystemHCL ParameterTypeSystem = "hcl"
)
func (e *ParameterTypeSystem) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ParameterTypeSystem(s)
case string:
*e = ParameterTypeSystem(s)
default:
return fmt.Errorf("unsupported scan type for ParameterTypeSystem: %T", src)
}
return nil
}
type ProvisionerJobType string
const (
ProvisionerJobTypeProjectVersionImport ProvisionerJobType = "project_version_import"
ProvisionerJobTypeWorkspaceBuild ProvisionerJobType = "workspace_build"
)
func (e *ProvisionerJobType) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ProvisionerJobType(s)
case string:
*e = ProvisionerJobType(s)
default:
return fmt.Errorf("unsupported scan type for ProvisionerJobType: %T", src)
}
return nil
}
type ProvisionerStorageMethod string
const (
ProvisionerStorageMethodFile ProvisionerStorageMethod = "file"
)
func (e *ProvisionerStorageMethod) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ProvisionerStorageMethod(s)
case string:
*e = ProvisionerStorageMethod(s)
default:
return fmt.Errorf("unsupported scan type for ProvisionerStorageMethod: %T", src)
}
return nil
}
type ProvisionerType string
const (
ProvisionerTypeEcho ProvisionerType = "echo"
ProvisionerTypeTerraform ProvisionerType = "terraform"
)
func (e *ProvisionerType) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = ProvisionerType(s)
case string:
*e = ProvisionerType(s)
default:
return fmt.Errorf("unsupported scan type for ProvisionerType: %T", src)
}
return nil
}
type UserStatus string
const (
UserstatusActive UserStatus = "active"
UserstatusDormant UserStatus = "dormant"
UserstatusDecommissioned UserStatus = "decommissioned"
)
func (e *UserStatus) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = UserStatus(s)
case string:
*e = UserStatus(s)
default:
return fmt.Errorf("unsupported scan type for UserStatus: %T", src)
}
return nil
}
type WorkspaceTransition string
const (
WorkspaceTransitionStart WorkspaceTransition = "start"
WorkspaceTransitionStop WorkspaceTransition = "stop"
WorkspaceTransitionDelete WorkspaceTransition = "delete"
)
func (e *WorkspaceTransition) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
*e = WorkspaceTransition(s)
case string:
*e = WorkspaceTransition(s)
default:
return fmt.Errorf("unsupported scan type for WorkspaceTransition: %T", src)
}
return nil
}
type APIKey struct {
ID string `db:"id" json:"id"`
HashedSecret []byte `db:"hashed_secret" json:"hashed_secret"`
UserID string `db:"user_id" json:"user_id"`
Application bool `db:"application" json:"application"`
Name string `db:"name" json:"name"`
LastUsed time.Time `db:"last_used" json:"last_used"`
ExpiresAt time.Time `db:"expires_at" json:"expires_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
LoginType LoginType `db:"login_type" json:"login_type"`
OIDCAccessToken string `db:"oidc_access_token" json:"oidc_access_token"`
OIDCRefreshToken string `db:"oidc_refresh_token" json:"oidc_refresh_token"`
OIDCIDToken string `db:"oidc_id_token" json:"oidc_id_token"`
OIDCExpiry time.Time `db:"oidc_expiry" json:"oidc_expiry"`
DevurlToken bool `db:"devurl_token" json:"devurl_token"`
}
type File struct {
Hash string `db:"hash" json:"hash"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
CreatedBy string `db:"created_by" json:"created_by"`
Mimetype string `db:"mimetype" json:"mimetype"`
Data []byte `db:"data" json:"data"`
}
type License struct {
ID int32 `db:"id" json:"id"`
License json.RawMessage `db:"license" json:"license"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
type Organization struct {
ID string `db:"id" json:"id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Default bool `db:"default" json:"default"`
AutoOffThreshold int64 `db:"auto_off_threshold" json:"auto_off_threshold"`
CpuProvisioningRate float32 `db:"cpu_provisioning_rate" json:"cpu_provisioning_rate"`
MemoryProvisioningRate float32 `db:"memory_provisioning_rate" json:"memory_provisioning_rate"`
WorkspaceAutoOff bool `db:"workspace_auto_off" json:"workspace_auto_off"`
}
type OrganizationMember struct {
OrganizationID string `db:"organization_id" json:"organization_id"`
UserID string `db:"user_id" json:"user_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Roles []string `db:"roles" json:"roles"`
}
type ParameterSchema struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
DefaultSourceScheme ParameterSourceScheme `db:"default_source_scheme" json:"default_source_scheme"`
DefaultSourceValue string `db:"default_source_value" json:"default_source_value"`
AllowOverrideSource bool `db:"allow_override_source" json:"allow_override_source"`
DefaultDestinationScheme ParameterDestinationScheme `db:"default_destination_scheme" json:"default_destination_scheme"`
AllowOverrideDestination bool `db:"allow_override_destination" json:"allow_override_destination"`
DefaultRefresh string `db:"default_refresh" json:"default_refresh"`
RedisplayValue bool `db:"redisplay_value" json:"redisplay_value"`
ValidationError string `db:"validation_error" json:"validation_error"`
ValidationCondition string `db:"validation_condition" json:"validation_condition"`
ValidationTypeSystem ParameterTypeSystem `db:"validation_type_system" json:"validation_type_system"`
ValidationValueType string `db:"validation_value_type" json:"validation_value_type"`
}
type ParameterValue struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Scope ParameterScope `db:"scope" json:"scope"`
ScopeID string `db:"scope_id" json:"scope_id"`
Name string `db:"name" json:"name"`
SourceScheme ParameterSourceScheme `db:"source_scheme" json:"source_scheme"`
SourceValue string `db:"source_value" json:"source_value"`
DestinationScheme ParameterDestinationScheme `db:"destination_scheme" json:"destination_scheme"`
}
type Project struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OrganizationID string `db:"organization_id" json:"organization_id"`
Deleted bool `db:"deleted" json:"deleted"`
Name string `db:"name" json:"name"`
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
ActiveVersionID uuid.UUID `db:"active_version_id" json:"active_version_id"`
}
type ProjectVersion struct {
ID uuid.UUID `db:"id" json:"id"`
ProjectID uuid.NullUUID `db:"project_id" json:"project_id"`
OrganizationID string `db:"organization_id" json:"organization_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
Description string `db:"description" json:"description"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
}
type ProvisionerDaemon struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
OrganizationID sql.NullString `db:"organization_id" json:"organization_id"`
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
}
type ProvisionerJob struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
CanceledAt sql.NullTime `db:"canceled_at" json:"canceled_at"`
CompletedAt sql.NullTime `db:"completed_at" json:"completed_at"`
Error sql.NullString `db:"error" json:"error"`
OrganizationID string `db:"organization_id" json:"organization_id"`
InitiatorID string `db:"initiator_id" json:"initiator_id"`
Provisioner ProvisionerType `db:"provisioner" json:"provisioner"`
StorageMethod ProvisionerStorageMethod `db:"storage_method" json:"storage_method"`
StorageSource string `db:"storage_source" json:"storage_source"`
Type ProvisionerJobType `db:"type" json:"type"`
Input json.RawMessage `db:"input" json:"input"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
}
type ProvisionerJobLog struct {
ID uuid.UUID `db:"id" json:"id"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
Source LogSource `db:"source" json:"source"`
Level LogLevel `db:"level" json:"level"`
Output string `db:"output" json:"output"`
}
type User struct {
ID string `db:"id" json:"id"`
Email string `db:"email" json:"email"`
Name string `db:"name" json:"name"`
Revoked bool `db:"revoked" json:"revoked"`
LoginType LoginType `db:"login_type" json:"login_type"`
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
TemporaryPassword bool `db:"temporary_password" json:"temporary_password"`
AvatarHash string `db:"avatar_hash" json:"avatar_hash"`
SshKeyRegeneratedAt time.Time `db:"ssh_key_regenerated_at" json:"ssh_key_regenerated_at"`
Username string `db:"username" json:"username"`
DotfilesGitUri string `db:"dotfiles_git_uri" json:"dotfiles_git_uri"`
Roles []string `db:"roles" json:"roles"`
Status UserStatus `db:"status" json:"status"`
Relatime time.Time `db:"relatime" json:"relatime"`
GpgKeyRegeneratedAt time.Time `db:"gpg_key_regenerated_at" json:"gpg_key_regenerated_at"`
Decomissioned bool `db:"_decomissioned" json:"_decomissioned"`
Shell string `db:"shell" json:"shell"`
}
type Workspace struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
OwnerID string `db:"owner_id" json:"owner_id"`
ProjectID uuid.UUID `db:"project_id" json:"project_id"`
Deleted bool `db:"deleted" json:"deleted"`
Name string `db:"name" json:"name"`
}
type WorkspaceAgent struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
}
type WorkspaceBuild struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
ProjectVersionID uuid.UUID `db:"project_version_id" json:"project_version_id"`
Name string `db:"name" json:"name"`
BeforeID uuid.NullUUID `db:"before_id" json:"before_id"`
AfterID uuid.NullUUID `db:"after_id" json:"after_id"`
Transition WorkspaceTransition `db:"transition" json:"transition"`
Initiator string `db:"initiator" json:"initiator"`
ProvisionerState []byte `db:"provisioner_state" json:"provisioner_state"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
}
type WorkspaceResource struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
JobID uuid.UUID `db:"job_id" json:"job_id"`
Transition WorkspaceTransition `db:"transition" json:"transition"`
Address string `db:"address" json:"address"`
Type string `db:"type" json:"type"`
Name string `db:"name" json:"name"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
}

View File

@ -0,0 +1,142 @@
package postgres
import (
"database/sql"
"fmt"
"io/ioutil"
"net"
"os"
"sync"
"time"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"golang.org/x/xerrors"
"github.com/coder/coder/cryptorand"
)
// Required to prevent port collision during container creation.
// Super unlikely, but it happened. See: https://github.com/coder/coder/runs/5375197003
var openPortMutex sync.Mutex
// Open creates a new PostgreSQL server using a Docker container.
func Open() (string, func(), error) {
if os.Getenv("DB") == "ci" {
// In CI, creating a Docker container for each test is slow.
// This expects a PostgreSQL instance with the hardcoded credentials
// available.
dbURL := "postgres://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable"
db, err := sql.Open("postgres", dbURL)
if err != nil {
return "", nil, xerrors.Errorf("connect to ci postgres: %w", err)
}
defer db.Close()
dbName, err := cryptorand.StringCharset(cryptorand.Lower, 10)
if err != nil {
return "", nil, xerrors.Errorf("generate db name: %w", err)
}
dbName = "ci" + dbName
_, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil {
return "", nil, xerrors.Errorf("create db: %w", err)
}
return "postgres://postgres:postgres@127.0.0.1:5432/" + dbName + "?sslmode=disable", func() {}, nil
}
pool, err := dockertest.NewPool("")
if err != nil {
return "", nil, xerrors.Errorf("create pool: %w", err)
}
tempDir, err := ioutil.TempDir(os.TempDir(), "postgres")
if err != nil {
return "", nil, xerrors.Errorf("create tempdir: %w", err)
}
openPortMutex.Lock()
// Pick an explicit port on the host to connect to 5432.
// This is necessary so we can configure the port to only use ipv4.
port, err := getFreePort()
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("Unable to get free port: %w", err)
}
resource, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "postgres",
Tag: "11",
Env: []string{
"POSTGRES_PASSWORD=postgres",
"POSTGRES_USER=postgres",
"POSTGRES_DB=postgres",
// The location for temporary database files!
"PGDATA=/tmp",
"listen_addresses = '*'",
},
PortBindings: map[docker.Port][]docker.PortBinding{
"5432/tcp": {{
// Manually specifying a host IP tells Docker just to use an IPV4 address.
// If we don't do this, we hit a fun bug:
// https://github.com/moby/moby/issues/42442
// where the ipv4 and ipv6 ports might be _different_ and collide with other running docker containers.
HostIP: "0.0.0.0",
HostPort: fmt.Sprintf("%d", port)}},
},
Mounts: []string{
// The postgres image has a VOLUME parameter in it's image.
// If we don't mount at this point, Docker will allocate a
// volume for this directory.
//
// This isn't used anyways, since we override PGDATA.
fmt.Sprintf("%s:/var/lib/postgresql/data", tempDir),
},
}, func(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("could not start resource: %w", err)
}
openPortMutex.Unlock()
hostAndPort := resource.GetHostPort("5432/tcp")
dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort)
// Docker should hard-kill the container after 120 seconds.
err = resource.Expire(120)
if err != nil {
return "", nil, xerrors.Errorf("could not expire resource: %w", err)
}
pool.MaxWait = 120 * time.Second
err = pool.Retry(func() error {
db, err := sql.Open("postgres", dbURL)
if err != nil {
return err
}
err = db.Ping()
_ = db.Close()
return err
})
if err != nil {
return "", nil, err
}
return dbURL, func() {
_ = pool.Purge(resource)
_ = os.RemoveAll(tempDir)
}, nil
}
// getFreePort asks the kernel for a free open port that is ready to use.
func getFreePort() (port int, err error) {
// Binding to port 0 tells the OS to grab a port for us:
// https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return 0, err
}
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}

View File

@ -0,0 +1,38 @@
//go:build linux
package postgres_test
import (
"database/sql"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/coderd/database/postgres"
_ "github.com/lib/pq"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestPostgres(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip()
return
}
connect, close, err := postgres.Open()
require.NoError(t, err)
defer close()
db, err := sql.Open("postgres", connect)
require.NoError(t, err)
err = db.Ping()
require.NoError(t, err)
err = db.Close()
require.NoError(t, err)
}

155
coderd/database/pubsub.go Normal file
View File

@ -0,0 +1,155 @@
package database
import (
"context"
"database/sql"
"errors"
"sync"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"golang.org/x/xerrors"
)
// Listener represents a pubsub handler.
type Listener func(ctx context.Context, message []byte)
// Pubsub is a generic interface for broadcasting and receiving messages.
// Implementors should assume high-availability with the backing implementation.
type Pubsub interface {
Subscribe(event string, listener Listener) (cancel func(), err error)
Publish(event string, message []byte) error
Close() error
}
// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[string]Listener
}
// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()
err = p.pgListener.Listen(event)
if errors.Is(err, pq.ErrChannelAlreadyOpen) {
// It's ok if it's already open!
err = nil
}
if err != nil {
return nil, xerrors.Errorf("listen: %w", err)
}
var listeners map[string]Listener
var ok bool
if listeners, ok = p.listeners[event]; !ok {
listeners = map[string]Listener{}
p.listeners[event] = listeners
}
var id string
for {
id = uuid.New().String()
if _, ok = listeners[id]; !ok {
break
}
}
listeners[id] = listener
return func() {
p.mut.Lock()
defer p.mut.Unlock()
listeners := p.listeners[event]
delete(listeners, id)
if len(listeners) == 0 {
_ = p.pgListener.Unlisten(event)
}
}, nil
}
func (p *pgPubsub) Publish(event string, message []byte) error {
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
return xerrors.Errorf("exec: %w", err)
}
return nil
}
// Close closes the pubsub instance.
func (p *pgPubsub) Close() error {
return p.pgListener.Close()
}
// listen begins receiving messages on the pq listener.
func (p *pgPubsub) listen(ctx context.Context) {
var (
notif *pq.Notification
ok bool
)
defer p.pgListener.Close()
for {
select {
case <-ctx.Done():
return
case notif, ok = <-p.pgListener.Notify:
if !ok {
return
}
}
// A nil notification can be dispatched on reconnect.
if notif == nil {
continue
}
p.listenReceive(ctx, notif)
}
}
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
p.mut.Lock()
defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel]
if !ok {
return
}
extra := []byte(notif.Extra)
for _, listener := range listeners {
go listener(ctx, extra)
}
}
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
case <-errCh:
return
default:
errCh <- err
close(errCh)
}
})
select {
case err := <-errCh:
if err != nil {
return nil, xerrors.Errorf("create pq listener: %w", err)
}
case <-ctx.Done():
return nil, ctx.Err()
}
pgPubsub := &pgPubsub{
db: database,
pgListener: listener,
listeners: make(map[string]map[string]Listener),
}
go pgPubsub.listen(ctx)
return pgPubsub, nil
}

View File

@ -0,0 +1,63 @@
package database
import (
"context"
"sync"
"github.com/google/uuid"
)
// memoryPubsub is an in-memory Pubsub implementation.
type memoryPubsub struct {
mut sync.RWMutex
listeners map[string]map[uuid.UUID]Listener
}
func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
m.mut.Lock()
defer m.mut.Unlock()
var listeners map[uuid.UUID]Listener
var ok bool
if listeners, ok = m.listeners[event]; !ok {
listeners = map[uuid.UUID]Listener{}
m.listeners[event] = listeners
}
var id uuid.UUID
for {
id = uuid.New()
if _, ok = listeners[id]; !ok {
break
}
}
listeners[id] = listener
return func() {
m.mut.Lock()
defer m.mut.Unlock()
listeners := m.listeners[event]
delete(listeners, id)
}, nil
}
func (m *memoryPubsub) Publish(event string, message []byte) error {
m.mut.RLock()
defer m.mut.RUnlock()
listeners, ok := m.listeners[event]
if !ok {
return nil
}
for _, listener := range listeners {
listener(context.Background(), message)
}
return nil
}
func (*memoryPubsub) Close() error {
return nil
}
func NewPubsubInMemory() Pubsub {
return &memoryPubsub{
listeners: make(map[string]map[uuid.UUID]Listener),
}
}

View File

@ -0,0 +1,35 @@
package database_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
)
func TestPubsubMemory(t *testing.T) {
t.Parallel()
t.Run("Memory", func(t *testing.T) {
t.Parallel()
pubsub := database.NewPubsubInMemory()
event := "test"
data := "testing"
messageChannel := make(chan []byte)
cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-messageChannel
assert.Equal(t, string(message), data)
})
}

View File

@ -0,0 +1,70 @@
//go:build linux
package database_test
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
)
func TestPubsub(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip()
return
}
t.Run("Postgres", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
connectionURL, close, err := postgres.Open()
require.NoError(t, err)
defer close()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
data := "testing"
messageChannel := make(chan []byte)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
require.NoError(t, err)
}()
message := <-messageChannel
assert.Equal(t, string(message), data)
})
t.Run("PostgresCloseCancel", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
connectionURL, close, err := postgres.Open()
require.NoError(t, err)
defer close()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
cancelFunc()
})
}

View File

@ -0,0 +1,84 @@
// Code generated by sqlc. DO NOT EDIT.
package database
import (
"context"
"github.com/google/uuid"
)
type querier interface {
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetFileByHash(ctx context.Context, hash string) (File, error)
GetOrganizationByID(ctx context.Context, id string) (Organization, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
GetOrganizationsByUserID(ctx context.Context, userID string) ([]Organization, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)
GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error)
GetProjectByID(ctx context.Context, id uuid.UUID) (Project, error)
GetProjectByOrganizationAndName(ctx context.Context, arg GetProjectByOrganizationAndNameParams) (Project, error)
GetProjectVersionByID(ctx context.Context, id uuid.UUID) (ProjectVersion, error)
GetProjectVersionByJobID(ctx context.Context, jobID uuid.UUID) (ProjectVersion, error)
GetProjectVersionByProjectIDAndName(ctx context.Context, arg GetProjectVersionByProjectIDAndNameParams) (ProjectVersion, error)
GetProjectVersionsByProjectID(ctx context.Context, dollar_1 uuid.UUID) ([]ProjectVersion, error)
GetProjectsByIDs(ctx context.Context, ids []uuid.UUID) ([]Project, error)
GetProjectsByOrganization(ctx context.Context, arg GetProjectsByOrganizationParams) ([]Project, error)
GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) (ProvisionerDaemon, error)
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id string) (User, error)
GetUserCount(ctx context.Context) (int64, error)
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceIDAndName(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDAndNameParams) (WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceIDWithoutAfter(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildsByWorkspaceIDsWithoutAfter(ctx context.Context, ids []uuid.UUID) ([]WorkspaceBuild, error)
GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error)
GetWorkspaceByUserIDAndName(ctx context.Context, arg GetWorkspaceByUserIDAndNameParams) (Workspace, error)
GetWorkspaceOwnerCountsByProjectIDs(ctx context.Context, ids []uuid.UUID) ([]GetWorkspaceOwnerCountsByProjectIDsRow, error)
GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID) (WorkspaceResource, error)
GetWorkspaceResourcesByJobID(ctx context.Context, jobID uuid.UUID) ([]WorkspaceResource, error)
GetWorkspacesByProjectID(ctx context.Context, arg GetWorkspacesByProjectIDParams) ([]Workspace, error)
GetWorkspacesByUserID(ctx context.Context, arg GetWorkspacesByUserIDParams) ([]Workspace, error)
InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error)
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
InsertParameterSchema(ctx context.Context, arg InsertParameterSchemaParams) (ParameterSchema, error)
InsertParameterValue(ctx context.Context, arg InsertParameterValueParams) (ParameterValue, error)
InsertProject(ctx context.Context, arg InsertProjectParams) (Project, error)
InsertProjectVersion(ctx context.Context, arg InsertProjectVersionParams) (ProjectVersion, error)
InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)
InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error)
InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)
InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error)
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateProjectActiveVersionByID(ctx context.Context, arg UpdateProjectActiveVersionByIDParams) error
UpdateProjectDeletedByID(ctx context.Context, arg UpdateProjectDeletedByIDParams) error
UpdateProjectVersionByID(ctx context.Context, arg UpdateProjectVersionByIDParams) error
UpdateProvisionerDaemonByID(ctx context.Context, arg UpdateProvisionerDaemonByIDParams) error
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error
UpdateWorkspaceDeletedByID(ctx context.Context, arg UpdateWorkspaceDeletedByIDParams) error
}
var _ querier = (*sqlQuerier)(nil)

791
coderd/database/query.sql Normal file
View File

@ -0,0 +1,791 @@
-- Database queries are generated using sqlc. See:
-- https://docs.sqlc.dev/en/latest/tutorials/getting-started-postgresql.html
--
-- Run "make gen" to generate models and query functions.
;
-- Acquires the lock for a single job that isn't started, completed,
-- canceled, and that matches an array of provisioner types.
--
-- SKIP LOCKED is used to jump over locked rows. This prevents
-- multiple provisioners from acquiring the same jobs. See:
-- https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
-- name: AcquireProvisionerJob :one
UPDATE
provisioner_jobs
SET
started_at = @started_at,
updated_at = @started_at,
worker_id = @worker_id
WHERE
id = (
SELECT
id
FROM
provisioner_jobs AS nested
WHERE
nested.started_at IS NULL
AND nested.canceled_at IS NULL
AND nested.completed_at IS NULL
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
ORDER BY
nested.created_at FOR
UPDATE
SKIP LOCKED
LIMIT
1
) RETURNING *;
-- name: DeleteParameterValueByID :exec
DELETE FROM
parameter_values
WHERE
id = $1;
-- name: GetAPIKeyByID :one
SELECT
*
FROM
api_keys
WHERE
id = $1
LIMIT
1;
-- name: GetFileByHash :one
SELECT
*
FROM
files
WHERE
hash = $1
LIMIT
1;
-- name: GetUserByID :one
SELECT
*
FROM
users
WHERE
id = $1
LIMIT
1;
-- name: GetUserByEmailOrUsername :one
SELECT
*
FROM
users
WHERE
LOWER(username) = LOWER(@username)
OR email = @email
LIMIT
1;
-- name: GetUserCount :one
SELECT
COUNT(*)
FROM
users;
-- name: GetOrganizationByID :one
SELECT
*
FROM
organizations
WHERE
id = $1;
-- name: GetOrganizationByName :one
SELECT
*
FROM
organizations
WHERE
LOWER(name) = LOWER(@name)
LIMIT
1;
-- name: GetOrganizationsByUserID :many
SELECT
*
FROM
organizations
WHERE
id = (
SELECT
organization_id
FROM
organization_members
WHERE
user_id = $1
);
-- name: GetOrganizationMemberByUserID :one
SELECT
*
FROM
organization_members
WHERE
organization_id = $1
AND user_id = $2
LIMIT
1;
-- name: GetParameterValuesByScope :many
SELECT
*
FROM
parameter_values
WHERE
scope = $1
AND scope_id = $2;
-- name: GetParameterValueByScopeAndName :one
SELECT
*
FROM
parameter_values
WHERE
scope = $1
AND scope_id = $2
AND name = $3
LIMIT
1;
-- name: GetProjectByID :one
SELECT
*
FROM
projects
WHERE
id = $1
LIMIT
1;
-- name: GetProjectsByIDs :many
SELECT
*
FROM
projects
WHERE
id = ANY(@ids :: uuid [ ]);
-- name: GetProjectByOrganizationAndName :one
SELECT
*
FROM
projects
WHERE
organization_id = @organization_id
AND deleted = @deleted
AND LOWER(name) = LOWER(@name)
LIMIT
1;
-- name: GetProjectsByOrganization :many
SELECT
*
FROM
projects
WHERE
organization_id = $1
AND deleted = $2;
-- name: GetParameterSchemasByJobID :many
SELECT
*
FROM
parameter_schemas
WHERE
job_id = $1;
-- name: GetProjectVersionsByProjectID :many
SELECT
*
FROM
project_versions
WHERE
project_id = $1 :: uuid;
-- name: GetProjectVersionByJobID :one
SELECT
*
FROM
project_versions
WHERE
job_id = $1;
-- name: GetProjectVersionByProjectIDAndName :one
SELECT
*
FROM
project_versions
WHERE
project_id = $1
AND name = $2;
-- name: GetProjectVersionByID :one
SELECT
*
FROM
project_versions
WHERE
id = $1;
-- name: GetProvisionerLogsByIDBetween :many
SELECT
*
FROM
provisioner_job_logs
WHERE
job_id = @job_id
AND (
created_at >= @created_after
OR created_at <= @created_before
)
ORDER BY
created_at;
-- name: GetProvisionerDaemonByID :one
SELECT
*
FROM
provisioner_daemons
WHERE
id = $1;
-- name: GetProvisionerDaemons :many
SELECT
*
FROM
provisioner_daemons;
-- name: GetWorkspaceAgentByAuthToken :one
SELECT
*
FROM
workspace_agents
WHERE
auth_token = $1
ORDER BY
created_at DESC;
-- name: GetWorkspaceAgentByInstanceID :one
SELECT
*
FROM
workspace_agents
WHERE
auth_instance_id = @auth_instance_id :: text
ORDER BY
created_at DESC;
-- name: GetProvisionerJobByID :one
SELECT
*
FROM
provisioner_jobs
WHERE
id = $1;
-- name: GetProvisionerJobsByIDs :many
SELECT
*
FROM
provisioner_jobs
WHERE
id = ANY(@ids :: uuid [ ]);
-- name: GetWorkspaceByID :one
SELECT
*
FROM
workspaces
WHERE
id = $1
LIMIT
1;
-- name: GetWorkspacesByProjectID :many
SELECT
*
FROM
workspaces
WHERE
project_id = $1
AND deleted = $2;
-- name: GetWorkspacesByUserID :many
SELECT
*
FROM
workspaces
WHERE
owner_id = $1
AND deleted = $2;
-- name: GetWorkspaceByUserIDAndName :one
SELECT
*
FROM
workspaces
WHERE
owner_id = @owner_id
AND deleted = @deleted
AND LOWER(name) = LOWER(@name);
-- name: GetWorkspaceOwnerCountsByProjectIDs :many
SELECT
project_id,
COUNT(DISTINCT owner_id)
FROM
workspaces
WHERE
project_id = ANY(@ids :: uuid [ ])
GROUP BY
project_id,
owner_id;
-- name: GetWorkspaceBuildByID :one
SELECT
*
FROM
workspace_builds
WHERE
id = $1
LIMIT
1;
-- name: GetWorkspaceBuildByJobID :one
SELECT
*
FROM
workspace_builds
WHERE
job_id = $1
LIMIT
1;
-- name: GetWorkspaceBuildByWorkspaceIDAndName :one
SELECT
*
FROM
workspace_builds
WHERE
workspace_id = $1
AND name = $2;
-- name: GetWorkspaceBuildByWorkspaceID :many
SELECT
*
FROM
workspace_builds
WHERE
workspace_id = $1;
-- name: GetWorkspaceBuildByWorkspaceIDWithoutAfter :one
SELECT
*
FROM
workspace_builds
WHERE
workspace_id = $1
AND after_id IS NULL
LIMIT
1;
-- name: GetWorkspaceBuildsByWorkspaceIDsWithoutAfter :many
SELECT
*
FROM
workspace_builds
WHERE
workspace_id = ANY(@ids :: uuid [ ])
AND after_id IS NULL;
-- name: GetWorkspaceResourceByID :one
SELECT
*
FROM
workspace_resources
WHERE
id = $1;
-- name: GetWorkspaceResourcesByJobID :many
SELECT
*
FROM
workspace_resources
WHERE
job_id = $1;
-- name: GetWorkspaceAgentByResourceID :one
SELECT
*
FROM
workspace_agents
WHERE
resource_id = $1;
-- name: InsertAPIKey :one
INSERT INTO
api_keys (
id,
hashed_secret,
user_id,
application,
name,
last_used,
expires_at,
created_at,
updated_at,
login_type,
oidc_access_token,
oidc_refresh_token,
oidc_id_token,
oidc_expiry,
devurl_token
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15
) RETURNING *;
-- name: InsertFile :one
INSERT INTO
files (hash, created_at, created_by, mimetype, data)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertProvisionerJobLogs :many
INSERT INTO
provisioner_job_logs
SELECT
unnest(@id :: uuid [ ]) AS id,
@job_id :: uuid AS job_id,
unnest(@created_at :: timestamptz [ ]) AS created_at,
unnest(@source :: log_source [ ]) as source,
unnest(@level :: log_level [ ]) as level,
unnest(@output :: varchar(1024) [ ]) as output RETURNING *;
-- name: InsertOrganization :one
INSERT INTO
organizations (id, name, description, created_at, updated_at)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertOrganizationMember :one
INSERT INTO
organization_members (
organization_id,
user_id,
created_at,
updated_at,
roles
)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertParameterValue :one
INSERT INTO
parameter_values (
id,
name,
created_at,
updated_at,
scope,
scope_id,
source_scheme,
source_value,
destination_scheme
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
-- name: InsertProject :one
INSERT INTO
projects (
id,
created_at,
updated_at,
organization_id,
name,
provisioner,
active_version_id
)
VALUES
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
-- name: InsertWorkspaceResource :one
INSERT INTO
workspace_resources (
id,
created_at,
job_id,
transition,
address,
type,
name,
agent_id
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
-- name: InsertProjectVersion :one
INSERT INTO
project_versions (
id,
project_id,
organization_id,
created_at,
updated_at,
name,
description,
job_id
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
-- name: InsertParameterSchema :one
INSERT INTO
parameter_schemas (
id,
created_at,
job_id,
name,
description,
default_source_scheme,
default_source_value,
allow_override_source,
default_destination_scheme,
allow_override_destination,
default_refresh,
redisplay_value,
validation_error,
validation_condition,
validation_type_system,
validation_value_type
)
VALUES
(
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9,
$10,
$11,
$12,
$13,
$14,
$15,
$16
) RETURNING *;
-- name: InsertProvisionerDaemon :one
INSERT INTO
provisioner_daemons (id, created_at, organization_id, name, provisioners)
VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: InsertProvisionerJob :one
INSERT INTO
provisioner_jobs (
id,
created_at,
updated_at,
organization_id,
initiator_id,
provisioner,
storage_method,
storage_source,
type,
input
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
-- name: InsertUser :one
INSERT INTO
users (
id,
email,
name,
login_type,
revoked,
hashed_password,
created_at,
updated_at,
username
)
VALUES
($1, $2, $3, $4, false, $5, $6, $7, $8) RETURNING *;
-- name: InsertWorkspace :one
INSERT INTO
workspaces (
id,
created_at,
updated_at,
owner_id,
project_id,
name
)
VALUES
($1, $2, $3, $4, $5, $6) RETURNING *;
-- name: InsertWorkspaceAgent :one
INSERT INTO
workspace_agents (
id,
created_at,
updated_at,
resource_id,
auth_token,
auth_instance_id,
environment_variables,
startup_script,
instance_metadata,
resource_metadata
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
-- name: InsertWorkspaceBuild :one
INSERT INTO
workspace_builds (
id,
created_at,
updated_at,
workspace_id,
project_version_id,
before_id,
name,
transition,
initiator,
job_id,
provisioner_state
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *;
-- name: UpdateAPIKeyByID :exec
UPDATE
api_keys
SET
last_used = $2,
expires_at = $3,
oidc_access_token = $4,
oidc_refresh_token = $5,
oidc_expiry = $6
WHERE
id = $1;
-- name: UpdateProjectActiveVersionByID :exec
UPDATE
projects
SET
active_version_id = $2
WHERE
id = $1;
-- name: UpdateProjectDeletedByID :exec
UPDATE
projects
SET
deleted = $2
WHERE
id = $1;
-- name: UpdateProjectVersionByID :exec
UPDATE
project_versions
SET
project_id = $2,
updated_at = $3
WHERE
id = $1;
-- name: UpdateProvisionerDaemonByID :exec
UPDATE
provisioner_daemons
SET
updated_at = $2,
provisioners = $3
WHERE
id = $1;
-- name: UpdateProvisionerJobByID :exec
UPDATE
provisioner_jobs
SET
updated_at = $2
WHERE
id = $1;
-- name: UpdateProvisionerJobWithCancelByID :exec
UPDATE
provisioner_jobs
SET
canceled_at = $2
WHERE
id = $1;
-- name: UpdateProvisionerJobWithCompleteByID :exec
UPDATE
provisioner_jobs
SET
updated_at = $2,
completed_at = $3,
canceled_at = $4,
error = $5
WHERE
id = $1;
-- name: UpdateWorkspaceDeletedByID :exec
UPDATE
workspaces
SET
deleted = $2
WHERE
id = $1;
-- name: UpdateWorkspaceAgentConnectionByID :exec
UPDATE
workspace_agents
SET
first_connected_at = $2,
last_connected_at = $3,
disconnected_at = $4
WHERE
id = $1;
-- name: UpdateWorkspaceBuildByID :exec
UPDATE
workspace_builds
SET
updated_at = $2,
after_id = $3,
provisioner_state = $4
WHERE
id = $1;

2718
coderd/database/query.sql.go Normal file

File diff suppressed because it is too large Load Diff

29
coderd/database/sqlc.yaml Normal file
View File

@ -0,0 +1,29 @@
# sqlc is used to generate types from sql schema language.
# It was chosen to ensure type-safety when interacting with
# the database.
version: "1"
packages:
- name: "database"
path: "."
queries: "./query.sql"
schema: "./dump.sql"
engine: "postgresql"
emit_interface: true
emit_json_tags: true
emit_db_tags: true
# We replace the generated db file with our own
# to add support for transactions. This file is
# deleted after generation.
output_db_file_name: db_tmp.go
overrides:
- db_type: citext
go_type: string
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
oidc_access_token: OIDCAccessToken
oidc_expiry: OIDCExpiry
oidc_id_token: OIDCIDToken
oidc_refresh_token: OIDCRefreshToken
parameter_type_system_hcl: ParameterTypeSystemHCL
userstatus: UserStatus

8
coderd/database/time.go Normal file
View File

@ -0,0 +1,8 @@
package database
import "time"
// Now returns a standardized timezone used for database resources.
func Now() time.Time {
return time.Now().UTC()
}

View File

@ -12,10 +12,10 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func (api *api) postFile(rw http.ResponseWriter, r *http.Request) {

117
coderd/httpapi/httpapi.go Normal file
View File

@ -0,0 +1,117 @@
package httpapi
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"regexp"
"strings"
"github.com/go-playground/validator/v10"
)
var (
validate *validator.Validate
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$")
)
// This init is used to create a validator and register validation-specific
// functionality for the HTTP API.
//
// A single validator instance is used, because it caches struct parsing.
func init() {
validate = validator.New()
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
err := validate.RegisterValidation("username", func(fl validator.FieldLevel) bool {
f := fl.Field().Interface()
str, ok := f.(string)
if !ok {
return false
}
if len(str) > 32 {
return false
}
if len(str) < 1 {
return false
}
return usernameRegex.MatchString(str)
})
if err != nil {
panic(err)
}
}
// Response represents a generic HTTP response.
type Response struct {
Message string `json:"message" validate:"required"`
Errors []Error `json:"errors,omitempty" validate:"required"`
}
// Error represents a scoped error to a user input.
type Error struct {
Field string `json:"field" validate:"required"`
Code string `json:"code" validate:"required"`
}
// Write outputs a standardized format to an HTTP response body.
func Write(rw http.ResponseWriter, status int, response Response) {
buf := &bytes.Buffer{}
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(true)
err := enc.Encode(response)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
rw.Header().Set("Content-Type", "application/json; charset=utf-8")
rw.WriteHeader(status)
_, err = rw.Write(buf.Bytes())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
}
// Read decodes JSON from the HTTP request into the value provided.
// It uses go-validator to validate the incoming request body.
func Read(rw http.ResponseWriter, r *http.Request, value interface{}) bool {
err := json.NewDecoder(r.Body).Decode(value)
if err != nil {
Write(rw, http.StatusBadRequest, Response{
Message: fmt.Sprintf("read body: %s", err.Error()),
})
return false
}
err = validate.Struct(value)
var validationErrors validator.ValidationErrors
if errors.As(err, &validationErrors) {
apiErrors := make([]Error, 0, len(validationErrors))
for _, validationError := range validationErrors {
apiErrors = append(apiErrors, Error{
Field: validationError.Field(),
Code: validationError.Tag(),
})
}
Write(rw, http.StatusBadRequest, Response{
Message: "Validation failed",
Errors: apiErrors,
})
return false
}
if err != nil {
Write(rw, http.StatusInternalServerError, Response{
Message: fmt.Sprintf("validation: %s", err.Error()),
})
return false
}
return true
}

View File

@ -0,0 +1,144 @@
package httpapi_test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/httpapi"
)
func TestWrite(t *testing.T) {
t.Parallel()
t.Run("NoErrors", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "wow",
})
var m map[string]interface{}
err := json.NewDecoder(rw.Body).Decode(&m)
require.NoError(t, err)
_, ok := m["errors"]
require.False(t, ok)
})
}
func TestRead(t *testing.T) {
t.Parallel()
t.Run("EmptyStruct", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
v := struct{}{}
require.True(t, httpapi.Read(rw, r, &v))
})
t.Run("NoBody", func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", nil)
var v json.RawMessage
require.False(t, httpapi.Read(rw, r, v))
})
t.Run("Validate", func(t *testing.T) {
t.Parallel()
type toValidate struct {
Value string `json:"value" validate:"required"`
}
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString(`{"value":"hi"}`))
var validate toValidate
require.True(t, httpapi.Read(rw, r, &validate))
require.Equal(t, validate.Value, "hi")
})
t.Run("ValidateFailure", func(t *testing.T) {
t.Parallel()
type toValidate struct {
Value string `json:"value" validate:"required"`
}
rw := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/", bytes.NewBufferString("{}"))
var validate toValidate
require.False(t, httpapi.Read(rw, r, &validate))
var v httpapi.Response
err := json.NewDecoder(rw.Body).Decode(&v)
require.NoError(t, err)
require.Len(t, v.Errors, 1)
require.Equal(t, v.Errors[0].Field, "value")
require.Equal(t, v.Errors[0].Code, "required")
})
}
func TestReadUsername(t *testing.T) {
t.Parallel()
// Tests whether usernames are valid or not.
testCases := []struct {
Username string
Valid bool
}{
{"1", true},
{"12", true},
{"123", true},
{"12345678901234567890", true},
{"123456789012345678901", true},
{"a", true},
{"a1", true},
{"a1b2", true},
{"a1b2c3d4e5f6g7h8i9j0", true},
{"a1b2c3d4e5f6g7h8i9j0k", true},
{"aa", true},
{"abc", true},
{"abcdefghijklmnopqrst", true},
{"abcdefghijklmnopqrstu", true},
{"wow-test", true},
{"", false},
{" ", false},
{" a", false},
{" a ", false},
{" 1", false},
{"1 ", false},
{" aa", false},
{"aa ", false},
{" 12", false},
{"12 ", false},
{" a1", false},
{"a1 ", false},
{" abcdefghijklmnopqrstu", false},
{"abcdefghijklmnopqrstu ", false},
{" 123456789012345678901", false},
{" a1b2c3d4e5f6g7h8i9j0k", false},
{"a1b2c3d4e5f6g7h8i9j0k ", false},
{"bananas_wow", false},
{"test--now", false},
{"123456789012345678901234567890123", false},
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
{"123456789012345678901234567890123123456789012345678901234567890123", false},
}
type toValidate struct {
Username string `json:"username" validate:"username"`
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.Username, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
data, err := json.Marshal(toValidate{testCase.Username})
require.NoError(t, err)
r := httptest.NewRequest("POST", "/", bytes.NewBuffer(data))
var validate toValidate
require.Equal(t, httpapi.Read(rw, r, &validate), testCase.Valid)
})
}
}

167
coderd/httpmw/apikey.go Normal file
View File

@ -0,0 +1,167 @@
package httpmw
import (
"context"
"crypto/sha256"
"crypto/subtle"
"database/sql"
"errors"
"fmt"
"net/http"
"strings"
"time"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
// AuthCookie represents the name of the cookie the API key is stored in.
const AuthCookie = "session_token"
// OAuth2Config contains a subset of functions exposed from oauth2.Config.
// It is abstracted for simple testing.
type OAuth2Config interface {
TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource
}
type apiKeyContextKey struct{}
// APIKey returns the API key from the ExtractAPIKey handler.
func APIKey(r *http.Request) database.APIKey {
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
if !ok {
panic("developer error: apikey middleware not provided")
}
return apiKey
}
// ExtractAPIKey requires authentication using a valid API key.
// It handles extending an API key if it comes close to expiry,
// updating the last used time in the database.
func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(AuthCookie)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("%q cookie must be provided", AuthCookie),
})
return
}
parts := strings.Split(cookie.Value, "-")
// APIKeys are formatted: ID-SECRET
if len(parts) != 2 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key format", AuthCookie),
})
return
}
keyID := parts[0]
keySecret := parts[1]
// Ensuring key lengths are valid.
if len(keyID) != 10 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie),
})
return
}
if len(keySecret) != 22 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie),
})
return
}
key, err := db.GetAPIKeyByID(r.Context(), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "api key is invalid",
})
return
}
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get api key by id: %s", err.Error()),
})
return
}
hashed := sha256.Sum256([]byte(keySecret))
// Checking to see if the secret is valid.
if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "api key secret is invalid",
})
return
}
now := database.Now()
// Tracks if the API key has properties updated!
changed := false
if key.LoginType == database.LoginTypeOIDC {
// Check if the OIDC token is expired!
if key.OIDCExpiry.Before(now) && !key.OIDCExpiry.IsZero() {
// If it is, let's refresh it from the provided config!
token, err := oauthConfig.TokenSource(r.Context(), &oauth2.Token{
AccessToken: key.OIDCAccessToken,
RefreshToken: key.OIDCRefreshToken,
Expiry: key.OIDCExpiry,
}).Token()
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("couldn't refresh expired oauth token: %s", err.Error()),
})
return
}
key.OIDCAccessToken = token.AccessToken
key.OIDCRefreshToken = token.RefreshToken
key.OIDCExpiry = token.Expiry
key.ExpiresAt = token.Expiry
changed = true
}
}
// Checking if the key is expired.
if key.ExpiresAt.Before(now) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("api key expired at %q", key.ExpiresAt.String()),
})
return
}
// Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now
changed = true
}
// Only update the ExpiresAt once an hour to prevent database spam.
// We extend the ExpiresAt to reduce reauthentication.
apiKeyLifetime := 24 * time.Hour
if key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour {
key.ExpiresAt = now.Add(apiKeyLifetime)
changed = true
}
if changed {
err := db.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
ID: key.ID,
ExpiresAt: key.ExpiresAt,
LastUsed: key.LastUsed,
OIDCAccessToken: key.OIDCAccessToken,
OIDCRefreshToken: key.OIDCRefreshToken,
OIDCExpiry: key.OIDCExpiry,
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("api key couldn't update: %s", err.Error()),
})
return
}
}
ctx := context.WithValue(r.Context(), apiKeyContextKey{}, key)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,375 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func randomAPIKeyParts() (id string, secret string) {
id, _ = cryptorand.String(10)
secret, _ = cryptorand.String(22)
return id, secret
}
func TestAPIKey(t *testing.T) {
t.Parallel()
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Only called if the API key passes through the handler.
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "it worked!",
})
})
t.Run("NoCookie", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidFormat", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "test-wow-hello",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidIDLength", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "test-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidSecretLength", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: "testtestid-wow",
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("InvalidSecret", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
// Use a different secret so they don't match!
hashed := sha256.Sum256([]byte("differentsecret"))
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Expired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Valid", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Checks that it exists on the context!
_ = httpmw.APIKey(r)
httpapi.Write(rw, http.StatusOK, httpapi.Response{
Message: "it worked!",
})
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.NotEqual(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("ValidUpdateExpiry", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.NotEqual(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LastUsed: database.Now(),
ExpiresAt: database.Now().AddDate(0, 0, 1),
})
require.NoError(t, err)
httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt)
})
t.Run("OIDCRefresh", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LoginType: database.LoginTypeOIDC,
LastUsed: database.Now(),
OIDCExpiry: database.Now().AddDate(0, 0, -1),
})
require.NoError(t, err)
token := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: database.Now().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKey(db, &oauth2Config{
tokenSource: &oauth2TokenSource{
token: func() (*oauth2.Token, error) {
return token, nil
},
},
})(successHandler).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err)
require.Equal(t, sentAPIKey.LastUsed, gotAPIKey.LastUsed)
require.Equal(t, token.Expiry, gotAPIKey.ExpiresAt)
require.Equal(t, token.AccessToken, gotAPIKey.OIDCAccessToken)
})
}
type oauth2Config struct {
tokenSource *oauth2TokenSource
}
func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
return o.tokenSource
}
type oauth2TokenSource struct {
token func() (*oauth2.Token, error)
}
func (o *oauth2TokenSource) Token() (*oauth2.Token, error) {
return o.token()
}

30
coderd/httpmw/httpmw.go Normal file
View File

@ -0,0 +1,30 @@
package httpmw
import (
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/coder/coder/coderd/httpapi"
)
// parseUUID consumes a url parameter and parses it as a UUID.
func parseUUID(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) {
rawID := chi.URLParam(r, param)
if rawID == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("%q must be provided", param),
})
return uuid.UUID{}, false
}
parsed, err := uuid.Parse(rawID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("%q must be a uuid", param),
})
return uuid.UUID{}, false
}
return parsed, true
}

View File

@ -0,0 +1,86 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type organizationParamContextKey struct{}
type organizationMemberParamContextKey struct{}
// OrganizationParam returns the organization from the ExtractOrganizationParam handler.
func OrganizationParam(r *http.Request) database.Organization {
organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
if !ok {
panic("developer error: organization param middleware not provided")
}
return organization
}
// OrganizationMemberParam returns the organization membership that allowed the query
// from the ExtractOrganizationParam handler.
func OrganizationMemberParam(r *http.Request) database.OrganizationMember {
organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember)
if !ok {
panic("developer error: organization param middleware not provided")
}
return organizationMember
}
// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter.
// This middleware requires the API key middleware higher in the call stack for authentication.
func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
organizationID := chi.URLParam(r, "organization")
if organizationID == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "organization must be provided",
})
return
}
organization, err := db.GetOrganizationByID(r.Context(), organizationID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("organization %q does not exist", organizationID),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organization: %s", err.Error()),
})
return
}
apiKey := APIKey(r)
organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{
OrganizationID: organization.ID,
UserID: apiKey.UserID,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "not a member of the organization",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get organization member: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,165 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestOrganizationParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.User) {
var (
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
hashed = sha256.Sum256([]byte(secret))
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
return r, user
}
t.Run("None", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
chi.RouteContext(r.Context()).URLParams.Add("organization", "nothin")
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("NotInOrganization", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, _ = setupAuthentication(db)
rtr = chi.NewRouter()
)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.NewString(),
Name: "test",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", nil)
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Success", func(t *testing.T) {
t.Parallel()
var (
db = databasefake.New()
rw = httptest.NewRecorder()
r, user = setupAuthentication(db)
rtr = chi.NewRouter()
)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.NewString(),
Name: "test",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("organization", organization.ID)
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.OrganizationParam(r)
_ = httpmw.OrganizationMemberParam(r)
rw.WriteHeader(http.StatusOK)
})
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,53 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type projectParamContextKey struct{}
// ProjectParam returns the project from the ExtractProjectParam handler.
func ProjectParam(r *http.Request) database.Project {
project, ok := r.Context().Value(projectParamContextKey{}).(database.Project)
if !ok {
panic("developer error: project param middleware not provided")
}
return project
}
// ExtractProjectParam grabs a project from the "project" URL parameter.
func ExtractProjectParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
projectID, parsed := parseUUID(rw, r, "project")
if !parsed {
return
}
project, err := db.GetProjectByID(r.Context(), projectID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("project %q does not exist", projectID),
})
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get project: %s", err),
})
return
}
ctx := context.WithValue(r.Context(), projectParamContextKey{}, project)
chi.RouteContext(ctx).URLParams.Add("organization", project.OrganizationID)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,159 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestProjectParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.Organization) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
orgID, err := cryptorand.String(16)
require.NoError(t, err)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: orgID,
Name: "banana",
Description: "wowie",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: orgID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, organization
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractProjectParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractProjectParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("project", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("BadUUID", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractProjectParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("project", "not-a-uuid")
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("Project", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractProjectParam(db),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.ProjectParam(r)
rw.WriteHeader(http.StatusOK)
})
r, org := setupAuthentication(db)
project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
ID: uuid.New(),
OrganizationID: org.ID,
Name: "moo",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("project", project.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,54 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type projectVersionParamContextKey struct{}
// ProjectVersionParam returns the project version from the ExtractProjectVersionParam handler.
func ProjectVersionParam(r *http.Request) database.ProjectVersion {
projectVersion, ok := r.Context().Value(projectVersionParamContextKey{}).(database.ProjectVersion)
if !ok {
panic("developer error: project version param middleware not provided")
}
return projectVersion
}
// ExtractProjectVersionParam grabs project version from the "projectversion" URL parameter.
func ExtractProjectVersionParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
projectVersionID, parsed := parseUUID(rw, r, "projectversion")
if !parsed {
return
}
projectVersion, err := db.GetProjectVersionByID(r.Context(), projectVersionID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("project version %q does not exist", projectVersionID),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get project version: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), projectVersionParamContextKey{}, projectVersion)
chi.RouteContext(ctx).URLParams.Add("organization", projectVersion.OrganizationID)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,150 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestProjectVersionParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.Project) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
orgID, err := cryptorand.String(16)
require.NoError(t, err)
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: orgID,
Name: "banana",
Description: "wowie",
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: orgID,
UserID: user.ID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
project, err := db.InsertProject(context.Background(), database.InsertProjectParams{
ID: uuid.New(),
OrganizationID: organization.ID,
Name: "moo",
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("organization", organization.Name)
ctx.URLParams.Add("project", project.Name)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, project
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractProjectVersionParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractProjectVersionParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("projectversion", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("ProjectVersion", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractProjectVersionParam(db),
httpmw.ExtractOrganizationParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.ProjectVersionParam(r)
rw.WriteHeader(http.StatusOK)
})
r, project := setupAuthentication(db)
projectVersion, err := db.InsertProjectVersion(context.Background(), database.InsertProjectVersionParams{
ID: uuid.New(),
OrganizationID: project.OrganizationID,
Name: "moo",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("projectversion", projectVersion.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,54 @@
package httpmw
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type userParamContextKey struct{}
// UserParam returns the user from the ExtractUserParam handler.
func UserParam(r *http.Request) database.User {
user, ok := r.Context().Value(userParamContextKey{}).(database.User)
if !ok {
panic("developer error: user parameter middleware not provided")
}
return user
}
// ExtractUserParam extracts a user from an ID/username in the {user} URL parameter.
func ExtractUserParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "user")
if userID == "" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "user id or name must be provided",
})
return
}
apiKey := APIKey(r)
if apiKey.UserID != userID && userID != "me" {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "getting non-personal users isn't supported yet",
})
return
}
user, err := db.GetUserByID(r.Context(), apiKey.UserID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get user: %s", err.Error()),
})
}
ctx := context.WithValue(r.Context(), userParamContextKey{}, user)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,104 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
)
func TestUserParam(t *testing.T) {
t.Parallel()
setup := func(t *testing.T) (database.Store, *httptest.ResponseRecorder, *http.Request) {
var (
db = databasefake.New()
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
_, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: "bananas",
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: "bananas",
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
return db, rw, r
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotMe", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
routeContext := chi.NewRouteContext()
routeContext.URLParams.Add("user", "ben")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("Me", func(t *testing.T) {
t.Parallel()
db, rw, r := setup(t)
httpmw.ExtractAPIKey(db, nil)(http.HandlerFunc(func(rw http.ResponseWriter, returnedRequest *http.Request) {
r = returnedRequest
})).ServeHTTP(rw, r)
routeContext := chi.NewRouteContext()
routeContext.URLParams.Add("user", "me")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
httpmw.ExtractUserParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.UserParam(r)
rw.WriteHeader(http.StatusOK)
})).ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,65 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type workspaceAgentContextKey struct{}
// WorkspaceAgent returns the workspace agent from the ExtractAgent handler.
func WorkspaceAgent(r *http.Request) database.WorkspaceAgent {
user, ok := r.Context().Value(workspaceAgentContextKey{}).(database.WorkspaceAgent)
if !ok {
panic("developer error: agent middleware not provided")
}
return user
}
// ExtractWorkspaceAgent requires authentication using a valid agent token.
func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(AuthCookie)
if err != nil {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: fmt.Sprintf("%q cookie must be provided", AuthCookie),
})
return
}
token, err := uuid.Parse(cookie.Value)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("parse token: %s", err),
})
return
}
agent, err := db.GetWorkspaceAgentByAuthToken(r.Context(), token)
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "agent token is invalid",
})
return
}
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace agent: %s", err),
})
return
}
ctx := context.WithValue(r.Context(), workspaceAgentContextKey{}, agent)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,73 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
)
func TestWorkspaceAgent(t *testing.T) {
t.Parallel()
setup := func(db database.Store) (*http.Request, uuid.UUID) {
token := uuid.New()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: token.String(),
})
return r, token
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(db),
)
rtr.Get("/", nil)
r, _ := setup(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceAgent(r)
rw.WriteHeader(http.StatusOK)
})
r, token := setup(db)
_, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
ID: uuid.New(),
AuthToken: token,
})
require.NoError(t, err)
require.NoError(t, err)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,56 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type workspaceBuildParamContextKey struct{}
// WorkspaceBuildParam returns the workspace build from the ExtractWorkspaceBuildParam handler.
func WorkspaceBuildParam(r *http.Request) database.WorkspaceBuild {
workspaceBuild, ok := r.Context().Value(workspaceBuildParamContextKey{}).(database.WorkspaceBuild)
if !ok {
panic("developer error: workspace build param middleware not provided")
}
return workspaceBuild
}
// ExtractWorkspaceBuildParam grabs workspace build from the "workspacebuild" URL parameter.
func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
workspaceBuildID, parsed := parseUUID(rw, r, "workspacebuild")
if !parsed {
return
}
workspaceBuild, err := db.GetWorkspaceBuildByID(r.Context(), workspaceBuildID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("workspace build %q does not exist", workspaceBuildID),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err.Error()),
})
return
}
ctx := context.WithValue(r.Context(), workspaceBuildParamContextKey{}, workspaceBuild)
// This injects the "workspace" parameter, because it's expected the consumer
// will want to use the Workspace middleware to ensure the caller owns the workspace.
chi.RouteContext(ctx).URLParams.Add("workspace", workspaceBuild.WorkspaceID.String())
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,134 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestWorkspaceBuildParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.Workspace) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
ID: uuid.New(),
ProjectID: uuid.New(),
OwnerID: user.ID,
Name: "potato",
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("user", userID)
ctx.URLParams.Add("workspace", workspace.Name)
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, workspace
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("workspacebuild", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("WorkspaceBuild", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractWorkspaceBuildParam(db),
httpmw.ExtractWorkspaceParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceBuildParam(r)
rw.WriteHeader(http.StatusOK)
})
r, workspace := setupAuthentication(db)
workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
ID: uuid.New(),
WorkspaceID: workspace.ID,
Name: "moo",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("workspacebuild", workspaceBuild.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,59 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type workspaceParamContextKey struct{}
// WorkspaceParam returns the workspace from the ExtractWorkspaceParam handler.
func WorkspaceParam(r *http.Request) database.Workspace {
workspace, ok := r.Context().Value(workspaceParamContextKey{}).(database.Workspace)
if !ok {
panic("developer error: workspace param middleware not provided")
}
return workspace
}
// ExtractWorkspaceParam grabs a workspace from the "workspace" URL parameter.
func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
workspaceID, parsed := parseUUID(rw, r, "workspace")
if !parsed {
return
}
workspace, err := db.GetWorkspaceByID(r.Context(), workspaceID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: fmt.Sprintf("workspace %q does not exist", workspaceID),
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace: %s", err.Error()),
})
return
}
apiKey := APIKey(r)
if apiKey.UserID != workspace.OwnerID {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "getting non-personal workspaces isn't supported",
})
return
}
ctx := context.WithValue(r.Context(), workspaceParamContextKey{}, workspace)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,148 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestWorkspaceParam(t *testing.T) {
t.Parallel()
setup := func(db database.Store) (*http.Request, database.User) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID, err := cryptorand.String(16)
require.NoError(t, err)
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("user", "me")
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, user
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceParam(db))
rtr.Get("/", nil)
r, _ := setup(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceParam(db))
rtr.Get("/", nil)
r, _ := setup(db)
chi.RouteContext(r.Context()).URLParams.Add("workspace", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("NonPersonal", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractWorkspaceParam(db),
)
rtr.Get("/", nil)
r, _ := setup(db)
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
ID: uuid.New(),
OwnerID: "not-me",
Name: "hello",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("workspace", workspace.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
})
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractWorkspaceParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceParam(r)
rw.WriteHeader(http.StatusOK)
})
r, user := setup(db)
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
ID: uuid.New(),
OwnerID: user.ID,
Name: "hello",
})
require.NoError(t, err)
chi.RouteContext(r.Context()).URLParams.Add("workspace", workspace.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -0,0 +1,76 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type workspaceResourceParamContextKey struct{}
// ProvisionerJobParam returns the project from the ExtractProjectParam handler.
func WorkspaceResourceParam(r *http.Request) database.WorkspaceResource {
resource, ok := r.Context().Value(workspaceResourceParamContextKey{}).(database.WorkspaceResource)
if !ok {
panic("developer error: workspace resource param middleware not provided")
}
return resource
}
// ExtractWorkspaceResourceParam grabs a workspace resource from the "provisionerjob" URL parameter.
func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
resourceUUID, parsed := parseUUID(rw, r, "workspaceresource")
if !parsed {
return
}
resource, err := db.GetWorkspaceResourceByID(r.Context(), resourceUUID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: "resource doesn't exist with that id",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get provisioner resource: %s", err),
})
return
}
job, err := db.GetProvisionerJobByID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get provisioner job: %s", err),
})
return
}
if job.Type != database.ProvisionerJobTypeWorkspaceBuild {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Workspace resources can only be fetched for builds.",
})
return
}
build, err := db.GetWorkspaceBuildByJobID(r.Context(), job.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
ctx := context.WithValue(r.Context(), workspaceResourceParamContextKey{}, resource)
ctx = context.WithValue(ctx, workspaceBuildParamContextKey{}, build)
chi.RouteContext(ctx).URLParams.Add("workspace", build.WorkspaceID.String())
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,122 @@
package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
)
func TestWorkspaceResourceParam(t *testing.T) {
t.Parallel()
setup := func(db database.Store, jobType database.ProvisionerJobType) (*http.Request, database.WorkspaceResource) {
r := httptest.NewRequest("GET", "/", nil)
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: uuid.New(),
Type: jobType,
})
require.NoError(t, err)
workspaceBuild, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
ID: uuid.New(),
JobID: job.ID,
})
require.NoError(t, err)
resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{
ID: uuid.New(),
JobID: job.ID,
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("workspacebuild", workspaceBuild.ID.String())
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, resource
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceResourceParam(db))
rtr.Get("/", nil)
r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceResourceParam(db),
)
rtr.Get("/", nil)
r, _ := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("FoundBadJobType", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceResourceParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceResourceParam(r)
rw.WriteHeader(http.StatusOK)
})
r, job := setup(db, database.ProvisionerJobTypeProjectVersionImport)
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceResourceParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceResourceParam(r)
rw.WriteHeader(http.StatusOK)
})
r, job := setup(db, database.ProvisionerJobTypeWorkspaceBuild)
chi.RouteContext(r.Context()).URLParams.Add("workspaceresource", job.ID.String())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -12,10 +12,10 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func (*api) organization(rw http.ResponseWriter, r *http.Request) {

View File

@ -9,8 +9,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/provisioner/echo"
)

View File

@ -8,7 +8,7 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/database"
"github.com/coder/coder/coderd/database"
)
// ComputeScope targets identifiers to pull parameters from.

View File

@ -7,10 +7,10 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/database/databasefake"
)
func TestCompute(t *testing.T) {

View File

@ -10,9 +10,9 @@ import (
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
func (api *api) postParameter(rw http.ResponseWriter, r *http.Request) {

View File

@ -8,8 +8,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
)
func TestPostParameter(t *testing.T) {

View File

@ -10,10 +10,10 @@ import (
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
// Returns a single project.

View File

@ -8,11 +8,11 @@ import (
"github.com/go-chi/render"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func (api *api) projectVersion(rw http.ResponseWriter, r *http.Request) {

View File

@ -23,9 +23,9 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"

View File

@ -15,9 +15,9 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
)
// Returns provisioner logs based on query parameters.

View File

@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/database"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)

View File

@ -15,12 +15,12 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/userpassword"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
// Returns whether the initial user has been created or not.

View File

@ -9,8 +9,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/httpmw"
)
func TestFirstUser(t *testing.T) {

View File

@ -7,10 +7,10 @@ import (
"github.com/go-chi/render"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func (api *api) workspaceBuild(rw http.ResponseWriter, r *http.Request) {

View File

@ -9,9 +9,9 @@ import (
"github.com/go-chi/render"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/mitchellh/mapstructure"
)

View File

@ -15,10 +15,10 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"

View File

@ -13,10 +13,10 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/httpapi"
"github.com/coder/coder/httpmw"
)
func (api *api) workspace(rw http.ResponseWriter, r *http.Request) {

View File

@ -9,8 +9,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)