feat: Allow workspace resources to attach multiple agents (#942)

This enables a "kubernetes_pod" to attach multiple agents that
could be for multiple services. Each agent is required to have
a unique name, so SSH syntax is:

`coder ssh <workspace>.<agent>`

A resource can have zero agents too, they aren't required.
This commit is contained in:
Kyle Carberry
2022-04-11 16:06:15 -05:00
committed by GitHub
parent 2835bb45e5
commit 19b4323512
47 changed files with 1550 additions and 1040 deletions

View File

@ -168,26 +168,31 @@ func New(options *Options) (http.Handler, func()) {
})
})
})
r.Route("/workspaceresources", func(r chi.Router) {
r.Route("/auth", func(r chi.Router) {
r.Post("/aws-instance-identity", api.postWorkspaceAuthAWSInstanceIdentity)
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
})
r.Route("/agent", func(r chi.Router) {
r.Route("/workspaceagents", func(r chi.Router) {
r.Post("/aws-instance-identity", api.postWorkspaceAuthAWSInstanceIdentity)
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
r.Route("/me", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
})
r.Route("/{workspaceresource}", func(r chi.Router) {
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractWorkspaceResourceParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
httpmw.ExtractWorkspaceAgentParam(options.Database),
)
r.Get("/", api.workspaceResource)
r.Get("/dial", api.workspaceResourceDial)
r.Get("/", api.workspaceAgent)
r.Get("/dial", api.workspaceAgentDial)
})
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),
httpmw.ExtractWorkspaceResourceParam(options.Database),
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceResource)
})
r.Route("/workspaces/{workspace}", func(r chi.Router) {
r.Use(
httpmw.ExtractAPIKey(options.Database, nil),

View File

@ -263,11 +263,10 @@ func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID
resources, err = client.WorkspaceResourcesByBuild(context.Background(), build)
require.NoError(t, err)
for _, resource := range resources {
if resource.Agent == nil {
continue
}
if resource.Agent.FirstConnectedAt == nil {
return false
for _, agent := range resource.Agents {
if agent.FirstConnectedAt == nil {
return false
}
}
}
return true

View File

@ -634,6 +634,20 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
return database.WorkspaceAgent{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgent) - 1; i >= 0; i-- {
agent := q.provisionerJobAgent[i]
if agent.ID.String() == id.String() {
return agent, nil
}
}
return database.WorkspaceAgent{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -648,16 +662,23 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
return database.WorkspaceAgent{}, sql.ErrNoRows
}
func (q *fakeQuerier) GetWorkspaceAgentByResourceID(_ context.Context, resourceID uuid.UUID) (database.WorkspaceAgent, error) {
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.provisionerJobAgent {
if agent.ResourceID.String() == resourceID.String() {
return agent, nil
for _, resourceID := range resourceIDs {
if agent.ResourceID.String() != resourceID.String() {
continue
}
workspaceAgents = append(workspaceAgents, agent)
}
}
return database.WorkspaceAgent{}, sql.ErrNoRows
if len(workspaceAgents) == 0 {
return nil, sql.ErrNoRows
}
return workspaceAgents, nil
}
func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) {
@ -982,6 +1003,9 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
@ -1003,7 +1027,6 @@ func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.In
Address: arg.Address,
Type: arg.Type,
Name: arg.Name,
AgentID: arg.AgentID,
}
q.provisionerJobResource = append(q.provisionerJobResource, resource)
return resource, nil

View File

@ -235,13 +235,16 @@ CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone NOT NULL,
name character varying(64) NOT NULL,
first_connected_at timestamp with time zone,
last_connected_at timestamp with time zone,
disconnected_at timestamp with time zone,
resource_id uuid NOT NULL,
auth_token uuid NOT NULL,
auth_instance_id character varying(64),
architecture character varying(64) NOT NULL,
environment_variables jsonb,
operating_system character varying(64) NOT NULL,
startup_script character varying(65534),
instance_metadata jsonb,
resource_metadata jsonb
@ -269,8 +272,7 @@ CREATE TABLE workspace_resources (
transition workspace_transition NOT NULL,
address character varying(256) NOT NULL,
type character varying(192) NOT NULL,
name character varying(64) NOT NULL,
agent_id uuid
name character varying(64) NOT NULL
);
CREATE TABLE workspaces (

View File

@ -69,7 +69,6 @@ CREATE TABLE workspace_resources (
address varchar(256) NOT NULL,
type varchar(192) NOT NULL,
name varchar(64) NOT NULL,
agent_id uuid,
PRIMARY KEY (id)
);
@ -77,13 +76,16 @@ CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL,
name varchar(64) NOT NULL,
first_connected_at timestamptz,
last_connected_at timestamptz,
disconnected_at timestamptz,
resource_id uuid NOT NULL REFERENCES workspace_resources (id) ON DELETE CASCADE,
auth_token uuid NOT NULL UNIQUE,
auth_instance_id varchar(64),
architecture varchar(64) NOT NULL,
environment_variables jsonb,
operating_system varchar(64) NOT NULL,
startup_script varchar(65534),
instance_metadata jsonb,
resource_metadata jsonb,

View File

@ -403,13 +403,16 @@ type WorkspaceAgent struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
@ -438,5 +441,4 @@ type WorkspaceResource struct {
Address string `db:"address" json:"address"`
Type string `db:"type" json:"type"`
Name string `db:"name" json:"name"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
}

View File

@ -39,8 +39,9 @@ type querier interface {
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserCount(ctx context.Context) (int64, error)
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error)
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
GetWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceBuild, error)

View File

@ -1907,7 +1907,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
SELECT
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
FROM
workspace_agents
WHERE
@ -1923,13 +1923,49 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
)
return i, err
}
const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
FROM
workspace_agents
WHERE
id = $1
`
func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByID, id)
var i WorkspaceAgent
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
@ -1939,7 +1975,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
SELECT
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
FROM
workspace_agents
WHERE
@ -1955,13 +1991,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
@ -1969,34 +2008,53 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
return i, err
}
const getWorkspaceAgentByResourceID = `-- name: GetWorkspaceAgentByResourceID :one
const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many
SELECT
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
FROM
workspace_agents
WHERE
resource_id = $1
resource_id = ANY($1 :: uuid [ ])
`
func (q *sqlQuerier) GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error) {
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByResourceID, resourceID)
var i WorkspaceAgent
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.EnvironmentVariables,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
)
return i, err
func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error) {
rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByResourceIDs, pq.Array(ids))
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkspaceAgent
for rows.Next() {
var i WorkspaceAgent
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertWorkspaceAgent = `-- name: InsertWorkspaceAgent :one
@ -2005,26 +2063,32 @@ INSERT INTO
id,
created_at,
updated_at,
name,
resource_id,
auth_token,
auth_instance_id,
architecture,
environment_variables,
operating_system,
startup_script,
instance_metadata,
resource_metadata
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
`
type InsertWorkspaceAgentParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
@ -2035,10 +2099,13 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
arg.ID,
arg.CreatedAt,
arg.UpdatedAt,
arg.Name,
arg.ResourceID,
arg.AuthToken,
arg.AuthInstanceID,
arg.Architecture,
arg.EnvironmentVariables,
arg.OperatingSystem,
arg.StartupScript,
arg.InstanceMetadata,
arg.ResourceMetadata,
@ -2048,13 +2115,16 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Name,
&i.FirstConnectedAt,
&i.LastConnectedAt,
&i.DisconnectedAt,
&i.ResourceID,
&i.AuthToken,
&i.AuthInstanceID,
&i.Architecture,
&i.EnvironmentVariables,
&i.OperatingSystem,
&i.StartupScript,
&i.InstanceMetadata,
&i.ResourceMetadata,
@ -2405,7 +2475,7 @@ func (q *sqlQuerier) UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWor
const getWorkspaceResourceByID = `-- name: GetWorkspaceResourceByID :one
SELECT
id, created_at, job_id, transition, address, type, name, agent_id
id, created_at, job_id, transition, address, type, name
FROM
workspace_resources
WHERE
@ -2423,14 +2493,13 @@ func (q *sqlQuerier) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID)
&i.Address,
&i.Type,
&i.Name,
&i.AgentID,
)
return i, err
}
const getWorkspaceResourcesByJobID = `-- name: GetWorkspaceResourcesByJobID :many
SELECT
id, created_at, job_id, transition, address, type, name, agent_id
id, created_at, job_id, transition, address, type, name
FROM
workspace_resources
WHERE
@ -2454,7 +2523,6 @@ func (q *sqlQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uui
&i.Address,
&i.Type,
&i.Name,
&i.AgentID,
); err != nil {
return nil, err
}
@ -2478,11 +2546,10 @@ INSERT INTO
transition,
address,
type,
name,
agent_id
name
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, job_id, transition, address, type, name, agent_id
($1, $2, $3, $4, $5, $6, $7) RETURNING id, created_at, job_id, transition, address, type, name
`
type InsertWorkspaceResourceParams struct {
@ -2493,7 +2560,6 @@ type InsertWorkspaceResourceParams struct {
Address string `db:"address" json:"address"`
Type string `db:"type" json:"type"`
Name string `db:"name" json:"name"`
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
}
func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) {
@ -2505,7 +2571,6 @@ func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWork
arg.Address,
arg.Type,
arg.Name,
arg.AgentID,
)
var i WorkspaceResource
err := row.Scan(
@ -2516,7 +2581,6 @@ func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWork
&i.Address,
&i.Type,
&i.Name,
&i.AgentID,
)
return i, err
}

View File

@ -8,6 +8,14 @@ WHERE
ORDER BY
created_at DESC;
-- name: GetWorkspaceAgentByID :one
SELECT
*
FROM
workspace_agents
WHERE
id = $1;
-- name: GetWorkspaceAgentByInstanceID :one
SELECT
*
@ -18,13 +26,13 @@ WHERE
ORDER BY
created_at DESC;
-- name: GetWorkspaceAgentByResourceID :one
-- name: GetWorkspaceAgentsByResourceIDs :many
SELECT
*
FROM
workspace_agents
WHERE
resource_id = $1;
resource_id = ANY(@ids :: uuid [ ]);
-- name: InsertWorkspaceAgent :one
INSERT INTO
@ -32,16 +40,19 @@ INSERT INTO
id,
created_at,
updated_at,
name,
resource_id,
auth_token,
auth_instance_id,
architecture,
environment_variables,
operating_system,
startup_script,
instance_metadata,
resource_metadata
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING *;
-- name: UpdateWorkspaceAgentConnectionByID :exec
UPDATE

View File

@ -23,8 +23,7 @@ INSERT INTO
transition,
address,
type,
name,
agent_id
name
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
($1, $2, $3, $4, $5, $6, $7) RETURNING *;

View File

@ -79,51 +79,40 @@ func TestGitSSHKey(t *testing.T) {
func TestAgentGitSSHKey(t *testing.T) {
t.Parallel()
agentClient := func(algo gitsshkey.Algorithm) *codersdk.Client {
client := coderdtest.New(t, &coderdtest.Options{
SSHKeygenAlgorithm: algo,
})
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agent: &proto.Agent{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
},
}},
},
}},
})
project := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, project.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
return agentClient
}
t.Run("AgentKey", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := agentClient(gitsshkey.AlgorithmEd25519)
agentKey, err := client.AgentGitSSHKey(ctx)
require.NoError(t, err)
require.NotEmpty(t, agentKey.PrivateKey)
},
}},
})
project := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, project.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentKey, err := agentClient.AgentGitSSHKey(context.Background())
require.NoError(t, err)
require.NotEmpty(t, agentKey.PrivateKey)
}

View File

@ -0,0 +1,94 @@
package httpmw
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
)
type workspaceAgentParamContextKey struct{}
// WorkspaceAgentParam returns the workspace agent from the ExtractWorkspaceAgentParam handler.
func WorkspaceAgentParam(r *http.Request) database.WorkspaceAgent {
user, ok := r.Context().Value(workspaceAgentParamContextKey{}).(database.WorkspaceAgent)
if !ok {
panic("developer error: agent middleware not provided")
}
return user
}
// ExtractWorkspaceAgentParam grabs a workspace agent from the "workspaceagent" URL parameter.
func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
agentUUID, parsed := parseUUID(rw, r, "workspaceagent")
if !parsed {
return
}
agent, err := db.GetWorkspaceAgentByID(r.Context(), agentUUID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
Message: "agent doesn't exist with that id",
})
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get agent: %s", err),
})
return
}
resource, err := db.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get resource: %s", err),
})
return
}
job, err := db.GetProvisionerJobByID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get job: %s", err),
})
return
}
if job.Type != database.ProvisionerJobTypeWorkspaceBuild {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Workspace agents can only be fetched for builds.",
})
return
}
build, err := db.GetWorkspaceBuildByJobID(r.Context(), job.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build: %s", err),
})
return
}
workspace, err := db.GetWorkspaceByID(r.Context(), build.WorkspaceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace: %s", err),
})
return
}
apiKey := APIKey(r)
if apiKey.UserID != workspace.OwnerID {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "getting non-personal agents isn't supported",
})
return
}
ctx := context.WithValue(r.Context(), workspaceAgentParamContextKey{}, agent)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}

View File

@ -0,0 +1,153 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/cryptorand"
)
func TestWorkspaceAgentParam(t *testing.T) {
t.Parallel()
setupAuthentication := func(db database.Store) (*http.Request, database.WorkspaceAgent) {
var (
id, secret = randomAPIKeyParts()
hashed = sha256.Sum256([]byte(secret))
)
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&http.Cookie{
Name: httpmw.AuthCookie,
Value: fmt.Sprintf("%s-%s", id, secret),
})
userID := uuid.New()
username, err := cryptorand.String(8)
require.NoError(t, err)
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
ID: userID,
Email: "testaccount@coder.com",
Name: "example",
LoginType: database.LoginTypeBuiltIn,
HashedPassword: hashed[:],
Username: username,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
})
require.NoError(t, err)
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id,
UserID: user.ID,
HashedSecret: hashed[:],
LastUsed: database.Now(),
ExpiresAt: database.Now().Add(time.Minute),
})
require.NoError(t, err)
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
ID: uuid.New(),
TemplateID: uuid.New(),
OwnerID: user.ID,
Name: "potato",
})
require.NoError(t, err)
build, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
ID: uuid.New(),
WorkspaceID: workspace.ID,
JobID: uuid.New(),
})
require.NoError(t, err)
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: build.JobID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
})
require.NoError(t, err)
resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{
ID: uuid.New(),
JobID: job.ID,
})
require.NoError(t, err)
agent, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
ID: uuid.New(),
ResourceID: resource.ID,
})
require.NoError(t, err)
ctx := chi.NewRouteContext()
ctx.URLParams.Add("user", userID.String())
ctx.URLParams.Add("workspaceagent", agent.ID.String())
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
return r, agent
}
t.Run("None", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractWorkspaceAgentParam(db))
rtr.Get("/", nil)
r, _ := setupAuthentication(db)
chi.RouteContext(r.Context()).URLParams.Add("workspaceagent", uuid.NewString())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusNotFound, res.StatusCode)
})
t.Run("WorkspaceAgent", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractAPIKey(db, nil),
httpmw.ExtractWorkspaceAgentParam(db),
)
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
_ = httpmw.WorkspaceAgentParam(r)
rw.WriteHeader(http.StatusOK)
})
r, _ := setupAuthentication(db)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
res := rw.Result()
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
})
}

View File

@ -587,25 +587,21 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
Address: address,
Type: protoResource.Type,
Name: protoResource.Name,
AgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: protoResource.Agent != nil,
},
})
if err != nil {
return xerrors.Errorf("insert provisioner job resource %q: %w", protoResource.Name, err)
}
if resource.AgentID.Valid {
for _, agent := range protoResource.Agents {
var instanceID sql.NullString
if protoResource.Agent.GetInstanceId() != "" {
if agent.GetInstanceId() != "" {
instanceID = sql.NullString{
String: protoResource.Agent.GetInstanceId(),
String: agent.GetInstanceId(),
Valid: true,
}
}
var env pqtype.NullRawMessage
if protoResource.Agent.Env != nil {
data, err := json.Marshal(protoResource.Agent.Env)
if agent.Env != nil {
data, err := json.Marshal(agent.Env)
if err != nil {
return xerrors.Errorf("marshal env: %w", err)
}
@ -615,24 +611,27 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
}
authToken := uuid.New()
if protoResource.Agent.GetToken() != "" {
authToken, err = uuid.Parse(protoResource.Agent.GetToken())
if agent.GetToken() != "" {
authToken, err = uuid.Parse(agent.GetToken())
if err != nil {
return xerrors.Errorf("invalid auth token format; must be uuid: %w", err)
}
}
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
ID: resource.AgentID.UUID,
ID: uuid.New(),
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
ResourceID: resource.ID,
Name: agent.Name,
AuthToken: authToken,
AuthInstanceID: instanceID,
Architecture: agent.Architecture,
EnvironmentVariables: env,
OperatingSystem: agent.OperatingSystem,
StartupScript: sql.NullString{
String: protoResource.Agent.StartupScript,
Valid: protoResource.Agent.StartupScript != "",
String: agent.StartupScript,
Valid: agent.StartupScript != "",
},
})
if err != nil {

View File

@ -196,27 +196,38 @@ func (api *api) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
})
return
}
resourceIDs := make([]uuid.UUID, 0)
for _, resource := range resources {
resourceIDs = append(resourceIDs, resource.ID)
}
resourceAgents, err := api.Database.GetWorkspaceAgentsByResourceIDs(r.Context(), resourceIDs)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace agents by resources: %s", err),
})
return
}
apiResources := make([]codersdk.WorkspaceResource, 0)
for _, resource := range resources {
if !resource.AgentID.Valid {
apiResources = append(apiResources, convertWorkspaceResource(resource, nil))
continue
agents := make([]codersdk.WorkspaceAgent, 0)
for _, agent := range resourceAgents {
if agent.ResourceID != resource.ID {
continue
}
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert provisioner job agent: %s", err),
})
return
}
agents = append(agents, apiAgent)
}
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), resource.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get provisioner job agent: %s", err),
})
return
}
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert provisioner job agent: %s", err),
})
return
}
apiResources = append(apiResources, convertWorkspaceResource(resource, &apiAgent))
apiResources = append(apiResources, convertWorkspaceResource(resource, agents))
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, apiResources)

View File

@ -210,10 +210,10 @@ func TestTemplateVersionResources(t *testing.T) {
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
},
}},
}, {
Name: "another",
Type: "example",
@ -229,7 +229,7 @@ func TestTemplateVersionResources(t *testing.T) {
require.Len(t, resources, 4)
require.Equal(t, "some", resources[0].Name)
require.Equal(t, "example", resources[0].Type)
require.NotNil(t, resources[0].Agent)
require.Len(t, resources[0].Agents, 1)
})
}
@ -255,12 +255,12 @@ func TestTemplateVersionLogs(t *testing.T) {
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
},
}},
}, {
Name: "another",
Type: "example",

257
coderd/workspaceagents.go Normal file
View File

@ -0,0 +1,257 @@
package coderd
import (
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/go-chi/render"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func (api *api) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
agent := httpmw.WorkspaceAgentParam(r)
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
})
return
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, apiAgent)
}
func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
agent := httpmw.WorkspaceAgentParam(r)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
}
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
agent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
})
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
defer closer.Close()
firstConnectedAt := agent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := agent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
ID: agent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
})
if err != nil {
return err
}
return nil
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID.String() != latestBuild.ID.String() {
return xerrors.New("build is outdated")
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-session.CloseChan():
return
case <-ticker.C:
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
}
}
agent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
UpdatedAt: dbAgent.UpdatedAt,
ResourceID: dbAgent.ResourceID,
InstanceID: dbAgent.AuthInstanceID.String,
Name: dbAgent.Name,
Architecture: dbAgent.Architecture,
OperatingSystem: dbAgent.OperatingSystem,
StartupScript: dbAgent.StartupScript.String,
EnvironmentVariables: envs,
}
if dbAgent.FirstConnectedAt.Valid {
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
}
if dbAgent.LastConnectedAt.Valid {
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
}
if dbAgent.DisconnectedAt.Valid {
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
}
switch {
case !dbAgent.FirstConnectedAt.Valid:
// If the agent never connected, it's waiting for the compute
// to start up.
agent.Status = codersdk.WorkspaceAgentWaiting
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
// If we've disconnected after our last connection, we know the
// agent is no longer connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
// The connection updated it's timestamp within the update frequency.
// We multiply by two to allow for some lag.
agent.Status = codersdk.WorkspaceAgentConnected
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
// The connection died without updating the last connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
}
return agent, nil
}

View File

@ -0,0 +1,108 @@
package coderd_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
func TestWorkspaceAgent(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
require.NoError(t, err)
_, err = client.WorkspaceAgent(context.Background(), resources[0].Agents[0].ID)
require.NoError(t, err)
}
func TestWorkspaceAgentListen(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
_, err = conn.Ping()
require.NoError(t, err)
}

View File

@ -106,7 +106,7 @@ func convertWorkspaceBuild(workspaceBuild database.WorkspaceBuild, job codersdk.
}
}
func convertWorkspaceResource(resource database.WorkspaceResource, agent *codersdk.WorkspaceAgent) codersdk.WorkspaceResource {
func convertWorkspaceResource(resource database.WorkspaceResource, agents []codersdk.WorkspaceAgent) codersdk.WorkspaceResource {
return codersdk.WorkspaceResource{
ID: resource.ID,
CreatedAt: resource.CreatedAt,
@ -115,6 +115,6 @@ func convertWorkspaceResource(resource database.WorkspaceResource, agent *coders
Address: resource.Address,
Type: resource.Type,
Name: resource.Name,
Agent: agent,
Agents: agents,
}
}

View File

@ -91,10 +91,10 @@ func TestWorkspaceBuildResources(t *testing.T) {
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
},
}},
}, {
Name: "another",
Type: "example",
@ -113,7 +113,7 @@ func TestWorkspaceBuildResources(t *testing.T) {
require.Len(t, resources, 2)
require.Equal(t, "some", resources[0].Name)
require.Equal(t, "example", resources[0].Type)
require.NotNil(t, resources[0].Agent)
require.Len(t, resources[0].Agents, 1)
})
}
@ -138,10 +138,10 @@ func TestWorkspaceBuildLogs(t *testing.T) {
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
},
}},
}, {
Name: "another",
Type: "example",

View File

@ -32,11 +32,11 @@ func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) {
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Auth: &proto.Agent_InstanceId{
InstanceId: instanceID,
},
},
}},
}},
},
},
@ -98,11 +98,11 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
Resources: []*proto.Resource{{
Name: "somename",
Type: "someinstance",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Auth: &proto.Agent_InstanceId{
InstanceId: instanceID,
},
},
}},
}},
},
},

View File

@ -2,26 +2,16 @@ package coderd
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/go-chi/render"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)
func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
@ -40,15 +30,18 @@ func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
})
return
}
var apiAgent *codersdk.WorkspaceAgent
if workspaceResource.AgentID.Valid {
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), workspaceResource.ID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get provisioner job agent: %s", err),
})
return
}
agents, err := api.Database.GetWorkspaceAgentsByResourceIDs(r.Context(), []uuid.UUID{workspaceResource.ID})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get provisioner job agents: %s", err),
})
return
}
apiAgents := make([]codersdk.WorkspaceAgent, 0)
for _, agent := range agents {
convertedAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@ -56,239 +49,9 @@ func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiAgent = &convertedAgent
apiAgents = append(apiAgents, convertedAgent)
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, convertWorkspaceResource(workspaceResource, apiAgent))
}
func (api *api) workspaceResourceDial(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
resource := httpmw.WorkspaceResourceParam(r)
if !resource.AgentID.Valid {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "resource doesn't have an agent",
})
return
}
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), resource.ID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("get provisioner job agent: %s", err),
})
return
}
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
}
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
agent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
})
return
}
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: agent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
})
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
defer closer.Close()
firstConnectedAt := agent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := agent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
ID: agent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
})
if err != nil {
return err
}
return nil
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID.String() != latestBuild.ID.String() {
return xerrors.New("build is outdated")
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-session.CloseChan():
return
case <-ticker.C:
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
}
}
agent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
UpdatedAt: dbAgent.UpdatedAt,
ResourceID: dbAgent.ResourceID,
InstanceID: dbAgent.AuthInstanceID.String,
StartupScript: dbAgent.StartupScript.String,
EnvironmentVariables: envs,
}
if dbAgent.FirstConnectedAt.Valid {
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
}
if dbAgent.LastConnectedAt.Valid {
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
}
if dbAgent.DisconnectedAt.Valid {
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
}
switch {
case !dbAgent.FirstConnectedAt.Valid:
// If the agent never connected, it's waiting for the compute
// to start up.
agent.Status = codersdk.WorkspaceAgentWaiting
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
// If we've disconnected after our last connection, we know the
// agent is no longer connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
// The connection updated it's timestamp within the update frequency.
// We multiply by two to allow for some lag.
agent.Status = codersdk.WorkspaceAgentConnected
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
// The connection died without updating the last connected.
agent.Status = codersdk.WorkspaceAgentDisconnected
}
return agent, nil
render.JSON(rw, r, convertWorkspaceResource(workspaceResource, apiAgents))
}

View File

@ -4,16 +4,10 @@ import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
@ -33,10 +27,10 @@ func TestWorkspaceResource(t *testing.T) {
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agent: &proto.Agent{
Agents: []*proto.Agent{{
Id: "something",
Auth: &proto.Agent_Token{},
},
}},
}},
},
},
@ -52,55 +46,3 @@ func TestWorkspaceResource(t *testing.T) {
require.NoError(t, err)
})
}
func TestWorkspaceAgentListen(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agent: &proto.Agent{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
Logger: slogtest.Make(t, nil),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].ID, nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
_, err = conn.Ping()
require.NoError(t, err)
}