feat: Add workspace application support (#1773)

* feat: Add app support

This adds apps as a property to a workspace agent.

The resource is added to the Terraform provider here:
https://github.com/coder/terraform-provider-coder/pull/17

Apps will be opened in the dashboard or via the CLI
with `coder open <name>`. If `command` is specified, a
terminal will appear locally and in the web. If `target`
is specified, the browser will open to an exposed instance
of that target.

* Compare fields in apps test

* Update Terraform provider to use relative path

* Add some basic structure for routing

* chore: Remove interface from coderd and lift API surface

Abstracting coderd into an interface added misdirection because
the interface was never intended to be fulfilled outside of a single
implementation.

This lifts the abstraction, and attaches all handlers to a root struct
named `*coderd.API`.

* Add basic proxy logic

* Add proxying based on path

* Add app proxying for wildcards

* Add wsconncache

* fix: Race when writing to a closed pipe

This is such an intermittent race it's difficult to track,
but regardless this is an improvement to the code.

* fix: Race when writing to a closed pipe

This is such an intermittent race it's difficult to track,
but regardless this is an improvement to the code.

* fix: Race when writing to a closed pipe

This is such an intermittent race it's difficult to track,
but regardless this is an improvement to the code.

* fix: Race when writing to a closed pipe

This is such an intermittent race it's difficult to track,
but regardless this is an improvement to the code.

* Add workspace route proxying endpoint

- Makes the workspace conn cache concurrency-safe
- Reduces unnecessary open checks in `peer.Channel`
- Fixes the use of a temporary context when dialing a workspace agent

* Add embed errors

* chore: Refactor site to improve testing

It was difficult to develop this package due to the
embed build tag being mandatory on the tests. The logic
to test doesn't require any embedded files.

* Add test for error handler

* Remove unused access url

* Add RBAC tests

* Fix dial agent syntax

* Fix linting errors

* Fix gen

* Fix icon required

* Adjust migration number

* Fix proxy error status code

* Fix empty db lookup
This commit is contained in:
Kyle Carberry
2022-06-04 15:13:37 -05:00
committed by GitHub
parent 2c089d5a99
commit 013f028e55
42 changed files with 1710 additions and 268 deletions

View File

@ -27,6 +27,7 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
)
@ -44,6 +45,7 @@ type Options struct {
// app. Specific routes may have their own limiters.
APIRateLimit int
AWSCertificates awsidentity.Certificates
Authorizer rbac.Authorizer
AzureCertificates x509.VerifyOptions
GoogleTokenValidator *idtoken.Validator
GithubOAuth2Config *GithubOAuth2Config
@ -51,7 +53,6 @@ type Options struct {
SecureAuthCookie bool
SSHKeygenAlgorithm gitsshkey.Algorithm
TURNServer *turnconn.Server
Authorizer rbac.Authorizer
TracerProvider *sdktrace.TracerProvider
}
@ -75,9 +76,11 @@ func New(options *Options) *API {
r := chi.NewRouter()
api := &API{
Options: options,
Handler: r,
Options: options,
Handler: r,
siteHandler: site.Handler(site.FS()),
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
@ -93,6 +96,20 @@ func New(options *Options) *API {
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
)
apps := func(r chi.Router) {
r.Use(
httpmw.RateLimitPerMinute(options.APIRateLimit),
apiKeyMiddleware,
httpmw.ExtractUserParam(api.Database),
)
r.Get("/*", api.workspaceAppsProxyPath)
}
// %40 is the encoded character of the @ symbol. VS Code Web does
// not handle character encoding properly, so it's safe to assume
// other applications might not as well.
r.Route("/%40{user}/{workspacename}/apps/{workspaceapp}", apps)
r.Route("/@{user}/{workspacename}/apps/{workspaceapp}", apps)
r.Route("/api/v2", func(r chi.Router) {
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
@ -327,24 +344,27 @@ func New(options *Options) *API {
r.Get("/state", api.workspaceBuildState)
})
})
r.NotFound(site.Handler(site.FS()).ServeHTTP)
r.NotFound(api.siteHandler.ServeHTTP)
return api
}
type API struct {
*Options
Handler chi.Router
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
Handler chi.Router
siteHandler http.Handler
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
}
// Close waits for all WebSocket connections to drain before returning.
func (api *API) Close() {
func (api *API) Close() error {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Wait()
api.websocketWaitMutex.Unlock()
return api.workspaceAgentCache.Close()
}
func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler {

View File

@ -74,6 +74,10 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{{
Name: "app",
Url: "http://localhost:3000",
}},
}},
}},
},
@ -128,6 +132,15 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"GET:/api/v2/users/authmethods": {NoAuthorize: true},
"POST:/api/v2/csp/reports": {NoAuthorize: true},
"GET:/%40{user}/{workspacename}/apps/{application}/*": {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"GET:/@{user}/{workspacename}/apps/{application}/*": {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
// Has it's own auth
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
@ -368,6 +381,7 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
route = strings.ReplaceAll(route, "{template}", template.ID.String())
route = strings.ReplaceAll(route, "{hash}", file.Hash)
route = strings.ReplaceAll(route, "{workspaceresource}", workspaceResources[0].ID.String())
route = strings.ReplaceAll(route, "{workspaceapp}", workspaceResources[0].Agents[0].Apps[0].Name)
route = strings.ReplaceAll(route, "{templateversion}", version.ID.String())
route = strings.ReplaceAll(route, "{templateversiondryrun}", templateVersionDryRun.ID.String())
route = strings.ReplaceAll(route, "{templatename}", template.Name)

View File

@ -173,7 +173,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API)
cancelFunc()
_ = turnServer.Close()
srv.Close()
coderAPI.Close()
_ = coderAPI.Close()
})
return codersdk.New(serverURL), coderAPI

View File

@ -35,6 +35,7 @@ func New() database.Store {
templateVersions: make([]database.TemplateVersion, 0),
templates: make([]database.Template, 0),
workspaceBuilds: make([]database.WorkspaceBuild, 0),
workspaceApps: make([]database.WorkspaceApp, 0),
workspaces: make([]database.Workspace, 0),
}
}
@ -63,6 +64,7 @@ type fakeQuerier struct {
templateVersions []database.TemplateVersion
templates []database.Template
workspaceBuilds []database.WorkspaceBuild
workspaceApps []database.WorkspaceApp
workspaces []database.Workspace
}
@ -388,6 +390,38 @@ func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg databa
return database.Workspace{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAppsByAgentID(_ context.Context, id uuid.UUID) ([]database.WorkspaceApp, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
apps := make([]database.WorkspaceApp, 0)
for _, app := range q.workspaceApps {
if app.AgentID == id {
apps = append(apps, app)
}
}
if len(apps) == 0 {
return nil, sql.ErrNoRows
}
return apps, nil
}
func (q *fakeQuerier) GetWorkspaceAppsByAgentIDs(_ context.Context, ids []uuid.UUID) ([]database.WorkspaceApp, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
apps := make([]database.WorkspaceApp, 0)
for _, app := range q.workspaceApps {
for _, id := range ids {
if app.AgentID.String() == id.String() {
apps = append(apps, app)
break
}
}
}
return apps, nil
}
func (q *fakeQuerier) GetWorkspacesAutostart(_ context.Context) ([]database.Workspace, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1031,6 +1065,22 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
return workspaceAgents, nil
}
func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, app := range q.workspaceApps {
if app.AgentID != arg.AgentID {
continue
}
if app.Name != arg.Name {
continue
}
return app, nil
}
return database.WorkspaceApp{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1521,6 +1571,25 @@ func (q *fakeQuerier) InsertWorkspaceBuild(_ context.Context, arg database.Inser
return workspaceBuild, nil
}
func (q *fakeQuerier) InsertWorkspaceApp(_ context.Context, arg database.InsertWorkspaceAppParams) (database.WorkspaceApp, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
// nolint:gosimple
workspaceApp := database.WorkspaceApp{
ID: arg.ID,
AgentID: arg.AgentID,
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Icon: arg.Icon,
Command: arg.Command,
Url: arg.Url,
RelativePath: arg.RelativePath,
}
q.workspaceApps = append(q.workspaceApps, workspaceApp)
return workspaceApp, nil
}
func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -280,6 +280,17 @@ CREATE TABLE workspace_agents (
directory character varying(4096) DEFAULT ''::character varying NOT NULL
);
CREATE TABLE workspace_apps (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
agent_id uuid NOT NULL,
name character varying(64) NOT NULL,
icon character varying(256) NOT NULL,
command character varying(65534),
url character varying(65534),
relative_path boolean DEFAULT false NOT NULL
);
CREATE TABLE workspace_builds (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@ -382,6 +393,12 @@ ALTER TABLE ONLY users
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_pkey PRIMARY KEY (id);
ALTER TABLE ONLY workspace_apps
ADD CONSTRAINT workspace_apps_agent_id_name_key UNIQUE (agent_id, name);
ALTER TABLE ONLY workspace_apps
ADD CONSTRAINT workspace_apps_pkey PRIMARY KEY (id);
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_job_id_key UNIQUE (job_id);
@ -463,6 +480,9 @@ ALTER TABLE ONLY templates
ALTER TABLE ONLY workspace_agents
ADD CONSTRAINT workspace_agents_resource_id_fkey FOREIGN KEY (resource_id) REFERENCES workspace_resources(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_apps
ADD CONSTRAINT workspace_apps_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
ALTER TABLE ONLY workspace_builds
ADD CONSTRAINT workspace_builds_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;

View File

@ -0,0 +1 @@
DROP TABLE workspace_apps;

View File

@ -0,0 +1,12 @@
CREATE TABLE workspace_apps (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
agent_id uuid NOT NULL REFERENCES workspace_agents (id) ON DELETE CASCADE,
name varchar(64) NOT NULL,
icon varchar(256) NOT NULL,
command varchar(65534),
url varchar(65534),
relative_path boolean NOT NULL DEFAULT false,
PRIMARY KEY (id),
UNIQUE(agent_id, name)
);

View File

@ -493,6 +493,17 @@ type WorkspaceAgent struct {
Directory string `db:"directory" json:"directory"`
}
type WorkspaceApp struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
Name string `db:"name" json:"name"`
Icon string `db:"icon" json:"icon"`
Command sql.NullString `db:"command" json:"command"`
Url sql.NullString `db:"url" json:"url"`
RelativePath bool `db:"relative_path" json:"relative_path"`
}
type WorkspaceBuild struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`

View File

@ -64,6 +64,9 @@ type querier interface {
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceAppByAgentIDAndName(ctx context.Context, arg GetWorkspaceAppByAgentIDAndNameParams) (WorkspaceApp, error)
GetWorkspaceAppsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceApp, error)
GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceApp, error)
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDParams) ([]WorkspaceBuild, error)
@ -93,6 +96,7 @@ type querier interface {
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
InsertWorkspace(ctx context.Context, arg InsertWorkspaceParams) (Workspace, error)
InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error)
InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error)
InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error)
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error

View File

@ -2785,6 +2785,155 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
return err
}
const getWorkspaceAppByAgentIDAndName = `-- name: GetWorkspaceAppByAgentIDAndName :one
SELECT id, created_at, agent_id, name, icon, command, url, relative_path FROM workspace_apps WHERE agent_id = $1 AND name = $2
`
type GetWorkspaceAppByAgentIDAndNameParams struct {
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
Name string `db:"name" json:"name"`
}
func (q *sqlQuerier) GetWorkspaceAppByAgentIDAndName(ctx context.Context, arg GetWorkspaceAppByAgentIDAndNameParams) (WorkspaceApp, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAppByAgentIDAndName, arg.AgentID, arg.Name)
var i WorkspaceApp
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.AgentID,
&i.Name,
&i.Icon,
&i.Command,
&i.Url,
&i.RelativePath,
)
return i, err
}
const getWorkspaceAppsByAgentID = `-- name: GetWorkspaceAppsByAgentID :many
SELECT id, created_at, agent_id, name, icon, command, url, relative_path FROM workspace_apps WHERE agent_id = $1
`
func (q *sqlQuerier) GetWorkspaceAppsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceApp, error) {
rows, err := q.db.QueryContext(ctx, getWorkspaceAppsByAgentID, agentID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceApp
for rows.Next() {
var i WorkspaceApp
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.AgentID,
&i.Name,
&i.Icon,
&i.Command,
&i.Url,
&i.RelativePath,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getWorkspaceAppsByAgentIDs = `-- name: GetWorkspaceAppsByAgentIDs :many
SELECT id, created_at, agent_id, name, icon, command, url, relative_path FROM workspace_apps WHERE agent_id = ANY($1 :: uuid [ ])
`
func (q *sqlQuerier) GetWorkspaceAppsByAgentIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceApp, error) {
rows, err := q.db.QueryContext(ctx, getWorkspaceAppsByAgentIDs, pq.Array(ids))
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceApp
for rows.Next() {
var i WorkspaceApp
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.AgentID,
&i.Name,
&i.Icon,
&i.Command,
&i.Url,
&i.RelativePath,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertWorkspaceApp = `-- name: InsertWorkspaceApp :one
INSERT INTO
workspace_apps (
id,
created_at,
agent_id,
name,
icon,
command,
url,
relative_path
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, agent_id, name, icon, command, url, relative_path
`
type InsertWorkspaceAppParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
Name string `db:"name" json:"name"`
Icon string `db:"icon" json:"icon"`
Command sql.NullString `db:"command" json:"command"`
Url sql.NullString `db:"url" json:"url"`
RelativePath bool `db:"relative_path" json:"relative_path"`
}
func (q *sqlQuerier) InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) {
row := q.db.QueryRowContext(ctx, insertWorkspaceApp,
arg.ID,
arg.CreatedAt,
arg.AgentID,
arg.Name,
arg.Icon,
arg.Command,
arg.Url,
arg.RelativePath,
)
var i WorkspaceApp
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.AgentID,
&i.Name,
&i.Icon,
&i.Command,
&i.Url,
&i.RelativePath,
)
return i, err
}
const getLatestWorkspaceBuildByWorkspaceID = `-- name: GetLatestWorkspaceBuildByWorkspaceID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, name, build_number, transition, initiator_id, provisioner_state, job_id, deadline

View File

@ -0,0 +1,23 @@
-- name: GetWorkspaceAppsByAgentID :many
SELECT * FROM workspace_apps WHERE agent_id = $1;
-- name: GetWorkspaceAppsByAgentIDs :many
SELECT * FROM workspace_apps WHERE agent_id = ANY(@ids :: uuid [ ]);
-- name: GetWorkspaceAppByAgentIDAndName :one
SELECT * FROM workspace_apps WHERE agent_id = $1 AND name = $2;
-- name: InsertWorkspaceApp :one
INSERT INTO
workspace_apps (
id,
created_at,
agent_id,
name,
icon,
command,
url,
relative_path
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;

View File

@ -17,7 +17,7 @@ import (
"github.com/coder/coder/coderd/httpapi"
)
// SessionTokenKey represents the name of the cookie or query paramater the API key is stored in.
// SessionTokenKey represents the name of the cookie or query parameter the API key is stored in.
const SessionTokenKey = "session_token"
type apiKeyContextKey struct{}

View File

@ -724,7 +724,7 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
}
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
ID: uuid.New(),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
@ -744,6 +744,28 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
if err != nil {
return xerrors.Errorf("insert agent: %w", err)
}
for _, app := range agent.Apps {
_, err := db.InsertWorkspaceApp(ctx, database.InsertWorkspaceAppParams{
ID: uuid.New(),
CreatedAt: database.Now(),
AgentID: dbAgent.ID,
Name: app.Name,
Icon: app.Icon,
Command: sql.NullString{
String: app.Command,
Valid: app.Command != "",
},
Url: sql.NullString{
String: app.Url,
Valid: app.Url != "",
},
RelativePath: app.RelativePath,
})
if err != nil {
return xerrors.Errorf("insert app: %w", err)
}
}
}
return nil
}

View File

@ -220,6 +220,20 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
})
return
}
resourceAgentIDs := make([]uuid.UUID, 0)
for _, agent := range resourceAgents {
resourceAgentIDs = append(resourceAgentIDs, agent.ID)
}
apps, err := api.Database.GetWorkspaceAppsByAgentIDs(r.Context(), resourceAgentIDs)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace apps: %s", err),
})
return
}
apiResources := make([]codersdk.WorkspaceResource, 0)
for _, resource := range resources {
@ -228,7 +242,14 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
if agent.ResourceID != resource.ID {
continue
}
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
dbApps := make([]database.WorkspaceApp, 0)
for _, app := range apps {
if app.AgentID == agent.ID {
dbApps = append(dbApps, app)
}
}
apiAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading job agent",

View File

@ -51,7 +51,7 @@ This can be represented by the following truth table, where Y represents *positi
## Example Permissions
- `+site.*.*.read`: allowed to perform the `read` action against all objects of type `devurl` in a given Coder deployment.
- `+site.*.*.read`: allowed to perform the `read` action against all objects of type `app` in a given Coder deployment.
- `-user.workspace.*.create`: user is not allowed to create workspaces.
## Roles

View File

@ -113,7 +113,7 @@ type Object struct {
// OrgID specifies which org the object is a part of.
OrgID string `json:"org_owner"`
// Type is "workspace", "project", "devurl", etc
// Type is "workspace", "project", "app", etc
Type string `json:"type"`
// TODO: SharedUsers?
}

View File

@ -31,7 +31,14 @@ import (
func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
dbApps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace agent apps: %s", err),
})
return
}
apiAgent, err := convertWorkspaceAgent(workspaceAgent, convertApps(dbApps), api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading workspace agent",
@ -50,7 +57,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading workspace agent",
@ -97,7 +104,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading workspace agent",
@ -358,7 +365,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading workspace agent",
@ -403,16 +410,16 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
_, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID)
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer agentConn.Close()
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), "")
defer release()
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
@ -428,8 +435,9 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if it's use is safe or r.Hijack() has
// not been performed.
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
client, server := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
ChannelID: agentID.String(),
@ -443,9 +451,12 @@ func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := peerClient.NegotiateConnection(ctx)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("negotiate: %w", err)
}
options := &peer.ConnOptions{}
options := &peer.ConnOptions{
Logger: api.Logger.Named("agent-dialer"),
}
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
options.SettingEngine.SetRelayAcceptanceMinWait(0)
// Use the ProxyDialer for the TURN server.
@ -476,15 +487,33 @@ func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID
}))
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("dial: %w", err)
}
go func() {
<-peerConn.Closed()
cancelFunc()
}()
return &agent.Conn{
Negotiator: peerClient,
Conn: peerConn,
}, nil
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
apps := make([]codersdk.WorkspaceApp, 0)
for _, dbApp := range dbApps {
apps = append(apps, codersdk.WorkspaceApp{
ID: dbApp.ID,
Name: dbApp.Name,
Command: dbApp.Command.String,
Icon: dbApp.Icon,
})
}
return apps
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
@ -504,6 +533,7 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
StartupScript: dbAgent.StartupScript.String,
EnvironmentVariables: envs,
Directory: dbAgent.Directory,
Apps: apps,
}
if dbAgent.FirstConnectedAt.Valid {
workspaceAgent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time

View File

@ -297,7 +297,7 @@ func TestWorkspaceAgentPTY(t *testing.T) {
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.WorkspaceAgentReconnectingPTY(context.Background(), resources[0].Agents[0].ID, uuid.New(), 80, 80)
conn, err := client.WorkspaceAgentReconnectingPTY(context.Background(), resources[0].Agents[0].ID, uuid.New(), 80, 80, "/bin/bash")
require.NoError(t, err)
defer conn.Close()

166
coderd/workspaceapps.go Normal file
View File

@ -0,0 +1,166 @@
package coderd
import (
"database/sql"
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/site"
)
// workspaceAppsProxyPath proxies requests to a workspace application
// through a relative URL path.
func (api *API) workspaceAppsProxyPath(rw http.ResponseWriter, r *http.Request) {
user := httpmw.UserParam(r)
// This can be in the form of: "<workspace-name>.[workspace-agent]" or "<workspace-name>"
workspaceWithAgent := chi.URLParam(r, "workspacename")
workspaceParts := strings.Split(workspaceWithAgent, ".")
workspace, err := api.Database.GetWorkspaceByOwnerIDAndName(r.Context(), database.GetWorkspaceByOwnerIDAndNameParams{
OwnerID: user.ID,
Name: workspaceParts[0],
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: "workspace not found",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace: %s", err),
})
return
}
if !api.Authorize(rw, r, rbac.ActionRead, workspace) {
return
}
build, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), workspace.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
resources, err := api.Database.GetWorkspaceResourcesByJobID(r.Context(), build.JobID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace resources: %s", err),
})
return
}
resourceIDs := make([]uuid.UUID, 0)
for _, resource := range resources {
resourceIDs = append(resourceIDs, resource.ID)
}
agents, err := api.Database.GetWorkspaceAgentsByResourceIDs(r.Context(), resourceIDs)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace agents: %s", err),
})
return
}
if len(agents) == 0 {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "no agents exist",
})
}
agent := agents[0]
if len(workspaceParts) > 1 {
for _, otherAgent := range agents {
if otherAgent.Name == workspaceParts[1] {
agent = otherAgent
break
}
}
}
app, err := api.Database.GetWorkspaceAppByAgentIDAndName(r.Context(), database.GetWorkspaceAppByAgentIDAndNameParams{
AgentID: agent.ID,
Name: chi.URLParam(r, "workspaceapp"),
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: "application not found",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace app: %s", err),
})
return
}
if !app.Url.Valid {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("application does not have a url: %s", err),
})
return
}
appURL, err := url.Parse(app.Url.String)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("parse app url: %s", err),
})
return
}
proxy := httputil.NewSingleHostReverseProxy(appURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
// This is a browser-facing route so JSON responses are not viable here.
// To pass friendly errors to the frontend, special meta tags are overridden
// in the index.html with the content passed here.
r = r.WithContext(site.WithAPIResponse(r.Context(), site.APIResponse{
StatusCode: http.StatusBadGateway,
Message: err.Error(),
}))
api.siteHandler.ServeHTTP(w, r)
}
path := chi.URLParam(r, "*")
if !strings.HasSuffix(r.URL.Path, "/") && path == "" {
// Web applications typically request paths relative to the
// root URL. This allows for routing behind a proxy or subpath.
// See https://github.com/coder/code-server/issues/241 for examples.
r.URL.Path += "/"
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
if r.URL.RawQuery == "" && appURL.RawQuery != "" {
// If the application defines a default set of query parameters,
// we should always respect them. The reverse proxy will merge
// query parameters for server-side requests, but sometimes
// client-side applications require the query parameters to render
// properly. With code-server, this is the "folder" param.
r.URL.RawQuery = appURL.RawQuery
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
r.URL.Path = path
conn, release, err := api.workspaceAgentCache.Acquire(r, agent.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("dial workspace agent: %s", err),
})
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
proxy.ServeHTTP(rw, r)
}

View File

@ -0,0 +1,125 @@
package coderd_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
func TestWorkspaceAppsProxyPath(t *testing.T) {
t.Parallel()
// #nosec
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
server := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
}
t.Cleanup(func() {
_ = server.Close()
_ = ln.Close()
})
go server.Serve(ln)
tcpAddr, _ := ln.Addr().(*net.TCPAddr)
client, coderAPI := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, coderAPI)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
Apps: []*proto.App{{
Name: "example",
Url: fmt.Sprintf("http://127.0.0.1:%d?query=true", tcpAddr.Port),
}, {
Name: "fake",
Url: "http://127.0.0.2",
}},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
t.Run("RedirectsWithSlash", func(t *testing.T) {
t.Parallel()
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
t.Run("RedirectsWithQuery", func(t *testing.T) {
t.Parallel()
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example/", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
loc, err := resp.Location()
require.NoError(t, err)
require.Equal(t, "query=true", loc.RawQuery)
})
t.Run("Proxies", func(t *testing.T) {
t.Parallel()
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example/?query=true", nil)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "", string(body))
require.Equal(t, http.StatusOK, resp.StatusCode)
})
t.Run("ProxyError", func(t *testing.T) {
t.Parallel()
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/fake/", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
})
}

View File

@ -3,10 +3,12 @@ package coderd
import (
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
@ -46,9 +48,27 @@ func (api *API) workspaceResource(rw http.ResponseWriter, r *http.Request) {
})
return
}
agentIDs := make([]uuid.UUID, 0)
for _, agent := range agents {
agentIDs = append(agentIDs, agent.ID)
}
apps, err := api.Database.GetWorkspaceAppsByAgentIDs(r.Context(), agentIDs)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace apps: %s", err),
})
return
}
apiAgents := make([]codersdk.WorkspaceAgent, 0)
for _, agent := range agents {
convertedAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
dbApps := make([]database.WorkspaceApp, 0)
for _, app := range apps {
if app.AgentID == agent.ID {
dbApps = append(dbApps, app)
}
}
convertedAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error reading workspace agent",

View File

@ -43,4 +43,50 @@ func TestWorkspaceResource(t *testing.T) {
_, err = client.WorkspaceResource(context.Background(), resources[0].ID)
require.NoError(t, err)
})
t.Run("Apps", func(t *testing.T) {
t.Parallel()
client, coderd := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdtest.NewProvisionerDaemon(t, coderd)
app := &proto.App{
Name: "code-server",
Command: "some-command",
Url: "http://localhost:3000",
Icon: "/code.svg",
}
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{app},
}},
}},
},
},
}},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
require.NoError(t, err)
resource, err := client.WorkspaceResource(context.Background(), resources[0].ID)
require.NoError(t, err)
require.Len(t, resource.Agents, 1)
agent := resource.Agents[0]
require.Len(t, agent.Apps, 1)
got := agent.Apps[0]
require.Equal(t, app.Command, got.Command)
require.Equal(t, app.Icon, got.Icon)
require.Equal(t, app.Name, got.Name)
})
}

View File

@ -0,0 +1,162 @@
// Package wsconncache caches workspace agent connections by UUID.
package wsconncache
import (
"context"
"net/http"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/atomic"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"github.com/coder/coder/agent"
)
// New creates a new workspace connection cache that closes
// connections after the inactive timeout provided.
//
// Agent connections are cached due to WebRTC negotiation
// taking a few hundred milliseconds.
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
if inactiveTimeout == 0 {
inactiveTimeout = 5 * time.Minute
}
return &Cache{
closed: make(chan struct{}),
dialer: dialer,
inactiveTimeout: inactiveTimeout,
}
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
*agent.Conn
locks atomic.Uint64
timeoutMutex sync.Mutex
timeout *time.Timer
timeoutCancel context.CancelFunc
transport *http.Transport
}
func (c *Conn) HTTPTransport() *http.Transport {
return c.transport
}
// CloseWithError ends the HTTP transport if exists, and closes the agent.
func (c *Conn) CloseWithError(err error) error {
if c.transport != nil {
c.transport.CloseIdleConnections()
}
c.timeoutMutex.Lock()
defer c.timeoutMutex.Unlock()
if c.timeout != nil {
c.timeout.Stop()
}
return c.Conn.CloseWithError(err)
}
type Cache struct {
closed chan struct{}
closeMutex sync.Mutex
closeGroup sync.WaitGroup
connGroup singleflight.Group
connMap sync.Map
dialer Dialer
inactiveTimeout time.Duration
}
// Acquire gets or establishes a connection with the dialer using the ID provided.
// If a connection is in-progress, that connection or error will be returned.
//
// The returned function is used to release a lock on the connection. Once zero
// locks exist on a connection, the inactive timeout will begin to tick down.
// After the time expires, the connection will be cleared from the cache.
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
rawConn, found := c.connMap.Load(id.String())
// If the connection isn't found, establish a new one!
if !found {
var err error
// A singleflight group is used to allow for concurrent requests to the
// same identifier to resolve.
rawConn, err, _ = c.connGroup.Do(id.String(), func() (interface{}, error) {
agentConn, err := c.dialer(r, id)
if err != nil {
return nil, xerrors.Errorf("dial: %w", err)
}
timeoutCtx, timeoutCancelFunc := context.WithCancel(context.Background())
defaultTransport, valid := http.DefaultTransport.(*http.Transport)
if !valid {
panic("dev error: default transport is the wrong type")
}
transport := defaultTransport.Clone()
transport.DialContext = agentConn.DialContext
conn := &Conn{
Conn: agentConn,
timeoutCancel: timeoutCancelFunc,
transport: transport,
}
c.closeMutex.Lock()
c.closeGroup.Add(1)
c.closeMutex.Unlock()
go func() {
defer c.closeGroup.Done()
var err error
select {
case <-timeoutCtx.Done():
err = xerrors.New("cache timeout")
case <-c.closed:
err = xerrors.New("cache closed")
case <-conn.Closed():
}
c.connMap.Delete(id.String())
c.connGroup.Forget(id.String())
_ = conn.CloseWithError(err)
}()
return conn, nil
})
if err != nil {
return nil, nil, err
}
c.connMap.Store(id.String(), rawConn)
}
conn, _ := rawConn.(*Conn)
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Inc()
return conn, func() {
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Dec()
if conn.locks.Load() == 0 {
conn.timeout = time.AfterFunc(c.inactiveTimeout, conn.timeoutCancel)
}
}, nil
}
func (c *Cache) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return nil
default:
}
close(c.closed)
c.closeGroup.Wait()
return nil
}

View File

@ -0,0 +1,175 @@
package wsconncache_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, 0)
t.Cleanup(func() {
_ = cache.Close()
})
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
require.True(t, conn1 == conn2)
})
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
called.Add(1)
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
t.Cleanup(func() {
_ = cache.Close()
})
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
require.Equal(t, int32(2), called.Load())
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
t.Cleanup(func() {
_ = cache.Close()
})
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
require.NoError(t, err)
time.Sleep(time.Millisecond)
release()
<-conn.Closed()
})
t.Run("HTTPTransport", func(t *testing.T) {
t.Parallel()
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
t.Cleanup(func() {
_ = random.Close()
})
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
}
t.Cleanup(func() {
_ = server.Close()
})
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
t.Cleanup(func() {
_ = cache.Close()
})
var wg sync.WaitGroup
// Perform many requests in parallel to simulate
// simultaneous HTTP requests.
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port),
Path: "/",
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
conn, release, err := cache.Acquire(req, uuid.Nil)
if !assert.NoError(t, err) {
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
res := httptest.NewRecorder()
proxy.ServeHTTP(res, req)
res.Result().Body.Close()
require.Equal(t, http.StatusOK, res.Result().StatusCode)
}()
}
wg.Wait()
})
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
}, nil
})
return metadata, listener, err
}, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
ReconnectingPTYTimeout: ptyTimeout,
})
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
})
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(context.Background())
assert.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
return &agent.Conn{
Negotiator: api,
Conn: conn,
}
}