mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: Allow workspace resources to attach multiple agents (#942)
This enables a "kubernetes_pod" to attach multiple agents that could be for multiple services. Each agent is required to have a unique name, so SSH syntax is: `coder ssh <workspace>.<agent>` A resource can have zero agents too, they aren't required.
This commit is contained in:
@ -168,26 +168,31 @@ func New(options *Options) (http.Handler, func()) {
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/workspaceresources", func(r chi.Router) {
|
||||
r.Route("/auth", func(r chi.Router) {
|
||||
r.Post("/aws-instance-identity", api.postWorkspaceAuthAWSInstanceIdentity)
|
||||
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
|
||||
})
|
||||
r.Route("/agent", func(r chi.Router) {
|
||||
r.Route("/workspaceagents", func(r chi.Router) {
|
||||
r.Post("/aws-instance-identity", api.postWorkspaceAuthAWSInstanceIdentity)
|
||||
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
|
||||
r.Route("/me", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
|
||||
r.Get("/", api.workspaceAgentListen)
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
})
|
||||
r.Route("/{workspaceresource}", func(r chi.Router) {
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
httpmw.ExtractWorkspaceResourceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceAgentParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceResource)
|
||||
r.Get("/dial", api.workspaceResourceDial)
|
||||
r.Get("/", api.workspaceAgent)
|
||||
r.Get("/dial", api.workspaceAgentDial)
|
||||
})
|
||||
})
|
||||
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
httpmw.ExtractWorkspaceResourceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceResource)
|
||||
})
|
||||
r.Route("/workspaces/{workspace}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractAPIKey(options.Database, nil),
|
||||
|
@ -263,11 +263,10 @@ func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID
|
||||
resources, err = client.WorkspaceResourcesByBuild(context.Background(), build)
|
||||
require.NoError(t, err)
|
||||
for _, resource := range resources {
|
||||
if resource.Agent == nil {
|
||||
continue
|
||||
}
|
||||
if resource.Agent.FirstConnectedAt == nil {
|
||||
return false
|
||||
for _, agent := range resource.Agents {
|
||||
if agent.FirstConnectedAt == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
@ -634,6 +634,20 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
|
||||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
// The schema sorts this by created at, so we iterate the array backwards.
|
||||
for i := len(q.provisionerJobAgent) - 1; i >= 0; i-- {
|
||||
agent := q.provisionerJobAgent[i]
|
||||
if agent.ID.String() == id.String() {
|
||||
return agent, nil
|
||||
}
|
||||
}
|
||||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceID string) (database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@ -648,16 +662,23 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI
|
||||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetWorkspaceAgentByResourceID(_ context.Context, resourceID uuid.UUID) (database.WorkspaceAgent, error) {
|
||||
func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourceIDs []uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
workspaceAgents := make([]database.WorkspaceAgent, 0)
|
||||
for _, agent := range q.provisionerJobAgent {
|
||||
if agent.ResourceID.String() == resourceID.String() {
|
||||
return agent, nil
|
||||
for _, resourceID := range resourceIDs {
|
||||
if agent.ResourceID.String() != resourceID.String() {
|
||||
continue
|
||||
}
|
||||
workspaceAgents = append(workspaceAgents, agent)
|
||||
}
|
||||
}
|
||||
return database.WorkspaceAgent{}, sql.ErrNoRows
|
||||
if len(workspaceAgents) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
return workspaceAgents, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetProvisionerDaemonByID(_ context.Context, id uuid.UUID) (database.ProvisionerDaemon, error) {
|
||||
@ -982,6 +1003,9 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
|
||||
AuthToken: arg.AuthToken,
|
||||
AuthInstanceID: arg.AuthInstanceID,
|
||||
EnvironmentVariables: arg.EnvironmentVariables,
|
||||
Name: arg.Name,
|
||||
Architecture: arg.Architecture,
|
||||
OperatingSystem: arg.OperatingSystem,
|
||||
StartupScript: arg.StartupScript,
|
||||
InstanceMetadata: arg.InstanceMetadata,
|
||||
ResourceMetadata: arg.ResourceMetadata,
|
||||
@ -1003,7 +1027,6 @@ func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.In
|
||||
Address: arg.Address,
|
||||
Type: arg.Type,
|
||||
Name: arg.Name,
|
||||
AgentID: arg.AgentID,
|
||||
}
|
||||
q.provisionerJobResource = append(q.provisionerJobResource, resource)
|
||||
return resource, nil
|
||||
|
6
coderd/database/dump.sql
generated
6
coderd/database/dump.sql
generated
@ -235,13 +235,16 @@ CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
first_connected_at timestamp with time zone,
|
||||
last_connected_at timestamp with time zone,
|
||||
disconnected_at timestamp with time zone,
|
||||
resource_id uuid NOT NULL,
|
||||
auth_token uuid NOT NULL,
|
||||
auth_instance_id character varying(64),
|
||||
architecture character varying(64) NOT NULL,
|
||||
environment_variables jsonb,
|
||||
operating_system character varying(64) NOT NULL,
|
||||
startup_script character varying(65534),
|
||||
instance_metadata jsonb,
|
||||
resource_metadata jsonb
|
||||
@ -269,8 +272,7 @@ CREATE TABLE workspace_resources (
|
||||
transition workspace_transition NOT NULL,
|
||||
address character varying(256) NOT NULL,
|
||||
type character varying(192) NOT NULL,
|
||||
name character varying(64) NOT NULL,
|
||||
agent_id uuid
|
||||
name character varying(64) NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE workspaces (
|
||||
|
@ -69,7 +69,6 @@ CREATE TABLE workspace_resources (
|
||||
address varchar(256) NOT NULL,
|
||||
type varchar(192) NOT NULL,
|
||||
name varchar(64) NOT NULL,
|
||||
agent_id uuid,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
@ -77,13 +76,16 @@ CREATE TABLE workspace_agents (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamptz NOT NULL,
|
||||
updated_at timestamptz NOT NULL,
|
||||
name varchar(64) NOT NULL,
|
||||
first_connected_at timestamptz,
|
||||
last_connected_at timestamptz,
|
||||
disconnected_at timestamptz,
|
||||
resource_id uuid NOT NULL REFERENCES workspace_resources (id) ON DELETE CASCADE,
|
||||
auth_token uuid NOT NULL UNIQUE,
|
||||
auth_instance_id varchar(64),
|
||||
architecture varchar(64) NOT NULL,
|
||||
environment_variables jsonb,
|
||||
operating_system varchar(64) NOT NULL,
|
||||
startup_script varchar(65534),
|
||||
instance_metadata jsonb,
|
||||
resource_metadata jsonb,
|
||||
|
@ -403,13 +403,16 @@ type WorkspaceAgent struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
|
||||
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
|
||||
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
|
||||
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
|
||||
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
|
||||
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
|
||||
Architecture string `db:"architecture" json:"architecture"`
|
||||
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
|
||||
OperatingSystem string `db:"operating_system" json:"operating_system"`
|
||||
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
|
||||
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
|
||||
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
|
||||
@ -438,5 +441,4 @@ type WorkspaceResource struct {
|
||||
Address string `db:"address" json:"address"`
|
||||
Type string `db:"type" json:"type"`
|
||||
Name string `db:"name" json:"name"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
}
|
||||
|
@ -39,8 +39,9 @@ type querier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserCount(ctx context.Context) (int64, error)
|
||||
GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error)
|
||||
GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error)
|
||||
GetWorkspaceBuildByID(ctx context.Context, id uuid.UUID) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByJobID(ctx context.Context, jobID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
|
@ -1907,7 +1907,7 @@ func (q *sqlQuerier) InsertUser(ctx context.Context, arg InsertUserParams) (User
|
||||
|
||||
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
|
||||
SELECT
|
||||
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
|
||||
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
@ -1923,13 +1923,49 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByID, id)
|
||||
var i WorkspaceAgent
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
@ -1939,7 +1975,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
|
||||
|
||||
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
|
||||
SELECT
|
||||
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
|
||||
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
@ -1955,13 +1991,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
@ -1969,34 +2008,53 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceAgentByResourceID = `-- name: GetWorkspaceAgentByResourceID :one
|
||||
const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many
|
||||
SELECT
|
||||
id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
|
||||
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
resource_id = $1
|
||||
resource_id = ANY($1 :: uuid [ ])
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetWorkspaceAgentByResourceID(ctx context.Context, resourceID uuid.UUID) (WorkspaceAgent, error) {
|
||||
row := q.db.QueryRowContext(ctx, getWorkspaceAgentByResourceID, resourceID)
|
||||
var i WorkspaceAgent
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.EnvironmentVariables,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
)
|
||||
return i, err
|
||||
func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []uuid.UUID) ([]WorkspaceAgent, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getWorkspaceAgentsByResourceIDs, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []WorkspaceAgent
|
||||
for rows.Next() {
|
||||
var i WorkspaceAgent
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertWorkspaceAgent = `-- name: InsertWorkspaceAgent :one
|
||||
@ -2005,26 +2063,32 @@ INSERT INTO
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
name,
|
||||
resource_id,
|
||||
auth_token,
|
||||
auth_instance_id,
|
||||
architecture,
|
||||
environment_variables,
|
||||
operating_system,
|
||||
startup_script,
|
||||
instance_metadata,
|
||||
resource_metadata
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, environment_variables, startup_script, instance_metadata, resource_metadata
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata
|
||||
`
|
||||
|
||||
type InsertWorkspaceAgentParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Name string `db:"name" json:"name"`
|
||||
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
|
||||
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
|
||||
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
|
||||
Architecture string `db:"architecture" json:"architecture"`
|
||||
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
|
||||
OperatingSystem string `db:"operating_system" json:"operating_system"`
|
||||
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
|
||||
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
|
||||
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
|
||||
@ -2035,10 +2099,13 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
|
||||
arg.ID,
|
||||
arg.CreatedAt,
|
||||
arg.UpdatedAt,
|
||||
arg.Name,
|
||||
arg.ResourceID,
|
||||
arg.AuthToken,
|
||||
arg.AuthInstanceID,
|
||||
arg.Architecture,
|
||||
arg.EnvironmentVariables,
|
||||
arg.OperatingSystem,
|
||||
arg.StartupScript,
|
||||
arg.InstanceMetadata,
|
||||
arg.ResourceMetadata,
|
||||
@ -2048,13 +2115,16 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.Name,
|
||||
&i.FirstConnectedAt,
|
||||
&i.LastConnectedAt,
|
||||
&i.DisconnectedAt,
|
||||
&i.ResourceID,
|
||||
&i.AuthToken,
|
||||
&i.AuthInstanceID,
|
||||
&i.Architecture,
|
||||
&i.EnvironmentVariables,
|
||||
&i.OperatingSystem,
|
||||
&i.StartupScript,
|
||||
&i.InstanceMetadata,
|
||||
&i.ResourceMetadata,
|
||||
@ -2405,7 +2475,7 @@ func (q *sqlQuerier) UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWor
|
||||
|
||||
const getWorkspaceResourceByID = `-- name: GetWorkspaceResourceByID :one
|
||||
SELECT
|
||||
id, created_at, job_id, transition, address, type, name, agent_id
|
||||
id, created_at, job_id, transition, address, type, name
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
@ -2423,14 +2493,13 @@ func (q *sqlQuerier) GetWorkspaceResourceByID(ctx context.Context, id uuid.UUID)
|
||||
&i.Address,
|
||||
&i.Type,
|
||||
&i.Name,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getWorkspaceResourcesByJobID = `-- name: GetWorkspaceResourcesByJobID :many
|
||||
SELECT
|
||||
id, created_at, job_id, transition, address, type, name, agent_id
|
||||
id, created_at, job_id, transition, address, type, name
|
||||
FROM
|
||||
workspace_resources
|
||||
WHERE
|
||||
@ -2454,7 +2523,6 @@ func (q *sqlQuerier) GetWorkspaceResourcesByJobID(ctx context.Context, jobID uui
|
||||
&i.Address,
|
||||
&i.Type,
|
||||
&i.Name,
|
||||
&i.AgentID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2478,11 +2546,10 @@ INSERT INTO
|
||||
transition,
|
||||
address,
|
||||
type,
|
||||
name,
|
||||
agent_id
|
||||
name
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, job_id, transition, address, type, name, agent_id
|
||||
($1, $2, $3, $4, $5, $6, $7) RETURNING id, created_at, job_id, transition, address, type, name
|
||||
`
|
||||
|
||||
type InsertWorkspaceResourceParams struct {
|
||||
@ -2493,7 +2560,6 @@ type InsertWorkspaceResourceParams struct {
|
||||
Address string `db:"address" json:"address"`
|
||||
Type string `db:"type" json:"type"`
|
||||
Name string `db:"name" json:"name"`
|
||||
AgentID uuid.NullUUID `db:"agent_id" json:"agent_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) {
|
||||
@ -2505,7 +2571,6 @@ func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWork
|
||||
arg.Address,
|
||||
arg.Type,
|
||||
arg.Name,
|
||||
arg.AgentID,
|
||||
)
|
||||
var i WorkspaceResource
|
||||
err := row.Scan(
|
||||
@ -2516,7 +2581,6 @@ func (q *sqlQuerier) InsertWorkspaceResource(ctx context.Context, arg InsertWork
|
||||
&i.Address,
|
||||
&i.Type,
|
||||
&i.Name,
|
||||
&i.AgentID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
@ -8,6 +8,14 @@ WHERE
|
||||
ORDER BY
|
||||
created_at DESC;
|
||||
|
||||
-- name: GetWorkspaceAgentByID :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentByInstanceID :one
|
||||
SELECT
|
||||
*
|
||||
@ -18,13 +26,13 @@ WHERE
|
||||
ORDER BY
|
||||
created_at DESC;
|
||||
|
||||
-- name: GetWorkspaceAgentByResourceID :one
|
||||
-- name: GetWorkspaceAgentsByResourceIDs :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
workspace_agents
|
||||
WHERE
|
||||
resource_id = $1;
|
||||
resource_id = ANY(@ids :: uuid [ ]);
|
||||
|
||||
-- name: InsertWorkspaceAgent :one
|
||||
INSERT INTO
|
||||
@ -32,16 +40,19 @@ INSERT INTO
|
||||
id,
|
||||
created_at,
|
||||
updated_at,
|
||||
name,
|
||||
resource_id,
|
||||
auth_token,
|
||||
auth_instance_id,
|
||||
architecture,
|
||||
environment_variables,
|
||||
operating_system,
|
||||
startup_script,
|
||||
instance_metadata,
|
||||
resource_metadata
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) RETURNING *;
|
||||
|
||||
-- name: UpdateWorkspaceAgentConnectionByID :exec
|
||||
UPDATE
|
||||
|
@ -23,8 +23,7 @@ INSERT INTO
|
||||
transition,
|
||||
address,
|
||||
type,
|
||||
name,
|
||||
agent_id
|
||||
name
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *;
|
||||
($1, $2, $3, $4, $5, $6, $7) RETURNING *;
|
||||
|
@ -79,51 +79,40 @@ func TestGitSSHKey(t *testing.T) {
|
||||
func TestAgentGitSSHKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
agentClient := func(algo gitsshkey.Algorithm) *codersdk.Client {
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
SSHKeygenAlgorithm: algo,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agent: &proto.Agent{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}},
|
||||
})
|
||||
project := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, project.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
daemonCloser.Close()
|
||||
|
||||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
|
||||
return agentClient
|
||||
}
|
||||
|
||||
t.Run("AgentKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
client := agentClient(gitsshkey.AlgorithmEd25519)
|
||||
agentKey, err := client.AgentGitSSHKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, agentKey.PrivateKey)
|
||||
},
|
||||
}},
|
||||
})
|
||||
project := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, project.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
daemonCloser.Close()
|
||||
|
||||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
|
||||
agentKey, err := agentClient.AgentGitSSHKey(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, agentKey.PrivateKey)
|
||||
}
|
||||
|
94
coderd/httpmw/workspaceagentparam.go
Normal file
94
coderd/httpmw/workspaceagentparam.go
Normal file
@ -0,0 +1,94 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
type workspaceAgentParamContextKey struct{}
|
||||
|
||||
// WorkspaceAgentParam returns the workspace agent from the ExtractWorkspaceAgentParam handler.
|
||||
func WorkspaceAgentParam(r *http.Request) database.WorkspaceAgent {
|
||||
user, ok := r.Context().Value(workspaceAgentParamContextKey{}).(database.WorkspaceAgent)
|
||||
if !ok {
|
||||
panic("developer error: agent middleware not provided")
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// ExtractWorkspaceAgentParam grabs a workspace agent from the "workspaceagent" URL parameter.
|
||||
func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
agentUUID, parsed := parseUUID(rw, r, "workspaceagent")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
agent, err := db.GetWorkspaceAgentByID(r.Context(), agentUUID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
|
||||
Message: "agent doesn't exist with that id",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
resource, err := db.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get resource: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
job, err := db.GetProvisionerJobByID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get job: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
if job.Type != database.ProvisionerJobTypeWorkspaceBuild {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "Workspace agents can only be fetched for builds.",
|
||||
})
|
||||
return
|
||||
}
|
||||
build, err := db.GetWorkspaceBuildByJobID(r.Context(), job.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace build: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
workspace, err := db.GetWorkspaceByID(r.Context(), build.WorkspaceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := APIKey(r)
|
||||
if apiKey.UserID != workspace.OwnerID {
|
||||
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
|
||||
Message: "getting non-personal agents isn't supported",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), workspaceAgentParamContextKey{}, agent)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
153
coderd/httpmw/workspaceagentparam_test.go
Normal file
153
coderd/httpmw/workspaceagentparam_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgentParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupAuthentication := func(db database.Store) (*http.Request, database.WorkspaceAgent) {
|
||||
var (
|
||||
id, secret = randomAPIKeyParts()
|
||||
hashed = sha256.Sum256([]byte(secret))
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.AddCookie(&http.Cookie{
|
||||
Name: httpmw.AuthCookie,
|
||||
Value: fmt.Sprintf("%s-%s", id, secret),
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
username, err := cryptorand.String(8)
|
||||
require.NoError(t, err)
|
||||
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
|
||||
ID: userID,
|
||||
Email: "testaccount@coder.com",
|
||||
Name: "example",
|
||||
LoginType: database.LoginTypeBuiltIn,
|
||||
HashedPassword: hashed[:],
|
||||
Username: username,
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
UserID: user.ID,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now(),
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
|
||||
ID: uuid.New(),
|
||||
TemplateID: uuid.New(),
|
||||
OwnerID: user.ID,
|
||||
Name: "potato",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
build, err := db.InsertWorkspaceBuild(context.Background(), database.InsertWorkspaceBuildParams{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: workspace.ID,
|
||||
JobID: uuid.New(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
|
||||
ID: build.JobID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := db.InsertWorkspaceResource(context.Background(), database.InsertWorkspaceResourceParams{
|
||||
ID: uuid.New(),
|
||||
JobID: job.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
|
||||
ID: uuid.New(),
|
||||
ResourceID: resource.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := chi.NewRouteContext()
|
||||
ctx.URLParams.Add("user", userID.String())
|
||||
ctx.URLParams.Add("workspaceagent", agent.ID.String())
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx))
|
||||
return r, agent
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceBuildParam(db))
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setupAuthentication(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusBadRequest, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractWorkspaceAgentParam(db))
|
||||
rtr.Get("/", nil)
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
chi.RouteContext(r.Context()).URLParams.Add("workspaceagent", uuid.NewString())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractAPIKey(db, nil),
|
||||
httpmw.ExtractWorkspaceAgentParam(db),
|
||||
)
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
_ = httpmw.WorkspaceAgentParam(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
r, _ := setupAuthentication(db)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
}
|
@ -587,25 +587,21 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
Address: address,
|
||||
Type: protoResource.Type,
|
||||
Name: protoResource.Name,
|
||||
AgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: protoResource.Agent != nil,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert provisioner job resource %q: %w", protoResource.Name, err)
|
||||
}
|
||||
if resource.AgentID.Valid {
|
||||
for _, agent := range protoResource.Agents {
|
||||
var instanceID sql.NullString
|
||||
if protoResource.Agent.GetInstanceId() != "" {
|
||||
if agent.GetInstanceId() != "" {
|
||||
instanceID = sql.NullString{
|
||||
String: protoResource.Agent.GetInstanceId(),
|
||||
String: agent.GetInstanceId(),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
var env pqtype.NullRawMessage
|
||||
if protoResource.Agent.Env != nil {
|
||||
data, err := json.Marshal(protoResource.Agent.Env)
|
||||
if agent.Env != nil {
|
||||
data, err := json.Marshal(agent.Env)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal env: %w", err)
|
||||
}
|
||||
@ -615,24 +611,27 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
}
|
||||
}
|
||||
authToken := uuid.New()
|
||||
if protoResource.Agent.GetToken() != "" {
|
||||
authToken, err = uuid.Parse(protoResource.Agent.GetToken())
|
||||
if agent.GetToken() != "" {
|
||||
authToken, err = uuid.Parse(agent.GetToken())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("invalid auth token format; must be uuid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
|
||||
ID: resource.AgentID.UUID,
|
||||
ID: uuid.New(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
ResourceID: resource.ID,
|
||||
Name: agent.Name,
|
||||
AuthToken: authToken,
|
||||
AuthInstanceID: instanceID,
|
||||
Architecture: agent.Architecture,
|
||||
EnvironmentVariables: env,
|
||||
OperatingSystem: agent.OperatingSystem,
|
||||
StartupScript: sql.NullString{
|
||||
String: protoResource.Agent.StartupScript,
|
||||
Valid: protoResource.Agent.StartupScript != "",
|
||||
String: agent.StartupScript,
|
||||
Valid: agent.StartupScript != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -196,27 +196,38 @@ func (api *api) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
|
||||
})
|
||||
return
|
||||
}
|
||||
resourceIDs := make([]uuid.UUID, 0)
|
||||
for _, resource := range resources {
|
||||
resourceIDs = append(resourceIDs, resource.ID)
|
||||
}
|
||||
resourceAgents, err := api.Database.GetWorkspaceAgentsByResourceIDs(r.Context(), resourceIDs)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get workspace agents by resources: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
apiResources := make([]codersdk.WorkspaceResource, 0)
|
||||
for _, resource := range resources {
|
||||
if !resource.AgentID.Valid {
|
||||
apiResources = append(apiResources, convertWorkspaceResource(resource, nil))
|
||||
continue
|
||||
agents := make([]codersdk.WorkspaceAgent, 0)
|
||||
for _, agent := range resourceAgents {
|
||||
if agent.ResourceID != resource.ID {
|
||||
continue
|
||||
}
|
||||
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("convert provisioner job agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
agents = append(agents, apiAgent)
|
||||
}
|
||||
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), resource.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner job agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("convert provisioner job agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
apiResources = append(apiResources, convertWorkspaceResource(resource, &apiAgent))
|
||||
apiResources = append(apiResources, convertWorkspaceResource(resource, agents))
|
||||
}
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, apiResources)
|
||||
|
@ -210,10 +210,10 @@ func TestTemplateVersionResources(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}},
|
||||
}, {
|
||||
Name: "another",
|
||||
Type: "example",
|
||||
@ -229,7 +229,7 @@ func TestTemplateVersionResources(t *testing.T) {
|
||||
require.Len(t, resources, 4)
|
||||
require.Equal(t, "some", resources[0].Name)
|
||||
require.Equal(t, "example", resources[0].Type)
|
||||
require.NotNil(t, resources[0].Agent)
|
||||
require.Len(t, resources[0].Agents, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@ -255,12 +255,12 @@ func TestTemplateVersionLogs(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: uuid.NewString(),
|
||||
},
|
||||
},
|
||||
}},
|
||||
}, {
|
||||
Name: "another",
|
||||
Type: "example",
|
||||
|
257
coderd/workspaceagents.go
Normal file
257
coderd/workspaceagents.go
Normal file
@ -0,0 +1,257 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
"github.com/hashicorp/yamux"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func (api *api) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
|
||||
agent := httpmw.WorkspaceAgentParam(r)
|
||||
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("convert workspace agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, apiAgent)
|
||||
}
|
||||
|
||||
func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
agent := httpmw.WorkspaceAgentParam(r)
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
|
||||
ChannelID: agent.ID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
agent := httpmw.WorkspaceAgent(r)
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
|
||||
ChannelID: agent.ID.String(),
|
||||
Pubsub: api.Pubsub,
|
||||
Logger: api.Logger.Named("peerbroker-proxy-listen"),
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
defer closer.Close()
|
||||
firstConnectedAt := agent.FirstConnectedAt
|
||||
if !firstConnectedAt.Valid {
|
||||
firstConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
lastConnectedAt := sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
disconnectedAt := agent.DisconnectedAt
|
||||
updateConnectionTimes := func() error {
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
ID: agent.ID,
|
||||
FirstConnectedAt: firstConnectedAt,
|
||||
LastConnectedAt: lastConnectedAt,
|
||||
DisconnectedAt: disconnectedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
// Ensure the resource is still valid!
|
||||
// We only accept agents for resources on the latest build.
|
||||
ensureLatestBuild := func() error {
|
||||
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if build.ID.String() != latestBuild.ID.String() {
|
||||
return xerrors.New("build is outdated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
disconnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
_ = updateConnectionTimes()
|
||||
}()
|
||||
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
|
||||
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
|
||||
|
||||
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-session.CloseChan():
|
||||
return
|
||||
case <-ticker.C:
|
||||
lastConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
// Disconnect agents that are no longer valid.
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
|
||||
var envs map[string]string
|
||||
if dbAgent.EnvironmentVariables.Valid {
|
||||
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
}
|
||||
agent := codersdk.WorkspaceAgent{
|
||||
ID: dbAgent.ID,
|
||||
CreatedAt: dbAgent.CreatedAt,
|
||||
UpdatedAt: dbAgent.UpdatedAt,
|
||||
ResourceID: dbAgent.ResourceID,
|
||||
InstanceID: dbAgent.AuthInstanceID.String,
|
||||
Name: dbAgent.Name,
|
||||
Architecture: dbAgent.Architecture,
|
||||
OperatingSystem: dbAgent.OperatingSystem,
|
||||
StartupScript: dbAgent.StartupScript.String,
|
||||
EnvironmentVariables: envs,
|
||||
}
|
||||
if dbAgent.FirstConnectedAt.Valid {
|
||||
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
|
||||
}
|
||||
if dbAgent.LastConnectedAt.Valid {
|
||||
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
|
||||
}
|
||||
if dbAgent.DisconnectedAt.Valid {
|
||||
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
|
||||
}
|
||||
switch {
|
||||
case !dbAgent.FirstConnectedAt.Valid:
|
||||
// If the agent never connected, it's waiting for the compute
|
||||
// to start up.
|
||||
agent.Status = codersdk.WorkspaceAgentWaiting
|
||||
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
|
||||
// If we've disconnected after our last connection, we know the
|
||||
// agent is no longer connected.
|
||||
agent.Status = codersdk.WorkspaceAgentDisconnected
|
||||
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
|
||||
// The connection updated it's timestamp within the update frequency.
|
||||
// We multiply by two to allow for some lag.
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
|
||||
// The connection died without updating the last connected.
|
||||
agent.Status = codersdk.WorkspaceAgentDisconnected
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
}
|
108
coderd/workspaceagents_test.go
Normal file
108
coderd/workspaceagents_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
daemonCloser.Close()
|
||||
|
||||
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = client.WorkspaceAgent(context.Background(), resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentListen(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
daemonCloser.Close()
|
||||
|
||||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentCloser.Close()
|
||||
})
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
_, err = conn.Ping()
|
||||
require.NoError(t, err)
|
||||
}
|
@ -106,7 +106,7 @@ func convertWorkspaceBuild(workspaceBuild database.WorkspaceBuild, job codersdk.
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceResource(resource database.WorkspaceResource, agent *codersdk.WorkspaceAgent) codersdk.WorkspaceResource {
|
||||
func convertWorkspaceResource(resource database.WorkspaceResource, agents []codersdk.WorkspaceAgent) codersdk.WorkspaceResource {
|
||||
return codersdk.WorkspaceResource{
|
||||
ID: resource.ID,
|
||||
CreatedAt: resource.CreatedAt,
|
||||
@ -115,6 +115,6 @@ func convertWorkspaceResource(resource database.WorkspaceResource, agent *coders
|
||||
Address: resource.Address,
|
||||
Type: resource.Type,
|
||||
Name: resource.Name,
|
||||
Agent: agent,
|
||||
Agents: agents,
|
||||
}
|
||||
}
|
||||
|
@ -91,10 +91,10 @@ func TestWorkspaceBuildResources(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}},
|
||||
}, {
|
||||
Name: "another",
|
||||
Type: "example",
|
||||
@ -113,7 +113,7 @@ func TestWorkspaceBuildResources(t *testing.T) {
|
||||
require.Len(t, resources, 2)
|
||||
require.Equal(t, "some", resources[0].Name)
|
||||
require.Equal(t, "example", resources[0].Type)
|
||||
require.NotNil(t, resources[0].Agent)
|
||||
require.Len(t, resources[0].Agents, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@ -138,10 +138,10 @@ func TestWorkspaceBuildLogs(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}},
|
||||
}, {
|
||||
Name: "another",
|
||||
Type: "example",
|
||||
|
@ -32,11 +32,11 @@ func TestPostWorkspaceAuthAWSInstanceIdentity(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_InstanceId{
|
||||
InstanceId: instanceID,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
@ -98,11 +98,11 @@ func TestPostWorkspaceAuthGoogleInstanceIdentity(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_InstanceId{
|
||||
InstanceId: instanceID,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
@ -2,26 +2,16 @@ package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/render"
|
||||
"github.com/hashicorp/yamux"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
|
||||
@ -40,15 +30,18 @@ func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
var apiAgent *codersdk.WorkspaceAgent
|
||||
if workspaceResource.AgentID.Valid {
|
||||
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), workspaceResource.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner job agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
agents, err := api.Database.GetWorkspaceAgentsByResourceIDs(r.Context(), []uuid.UUID{workspaceResource.ID})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner job agents: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
apiAgents := make([]codersdk.WorkspaceAgent, 0)
|
||||
for _, agent := range agents {
|
||||
convertedAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
@ -56,239 +49,9 @@ func (api *api) workspaceResource(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
apiAgent = &convertedAgent
|
||||
apiAgents = append(apiAgents, convertedAgent)
|
||||
}
|
||||
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, convertWorkspaceResource(workspaceResource, apiAgent))
|
||||
}
|
||||
|
||||
func (api *api) workspaceResourceDial(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
resource := httpmw.WorkspaceResourceParam(r)
|
||||
if !resource.AgentID.Valid {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "resource doesn't have an agent",
|
||||
})
|
||||
return
|
||||
}
|
||||
agent, err := api.Database.GetWorkspaceAgentByResourceID(r.Context(), resource.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("get provisioner job agent: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
|
||||
ChannelID: agent.ID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
api.websocketWaitMutex.Lock()
|
||||
api.websocketWaitGroup.Add(1)
|
||||
api.websocketWaitMutex.Unlock()
|
||||
defer api.websocketWaitGroup.Done()
|
||||
|
||||
agent := httpmw.WorkspaceAgent(r)
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("accept websocket: %s", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
|
||||
ChannelID: agent.ID.String(),
|
||||
Pubsub: api.Pubsub,
|
||||
Logger: api.Logger.Named("peerbroker-proxy-listen"),
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
defer closer.Close()
|
||||
firstConnectedAt := agent.FirstConnectedAt
|
||||
if !firstConnectedAt.Valid {
|
||||
firstConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
lastConnectedAt := sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
disconnectedAt := agent.DisconnectedAt
|
||||
updateConnectionTimes := func() error {
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
ID: agent.ID,
|
||||
FirstConnectedAt: firstConnectedAt,
|
||||
LastConnectedAt: lastConnectedAt,
|
||||
DisconnectedAt: disconnectedAt,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
// Ensure the resource is still valid!
|
||||
// We only accept agents for resources on the latest build.
|
||||
ensureLatestBuild := func() error {
|
||||
latestBuild, err := api.Database.GetWorkspaceBuildByWorkspaceIDWithoutAfter(r.Context(), build.WorkspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if build.ID.String() != latestBuild.ID.String() {
|
||||
return xerrors.New("build is outdated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
disconnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
_ = updateConnectionTimes()
|
||||
}()
|
||||
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
|
||||
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
|
||||
|
||||
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-session.CloseChan():
|
||||
return
|
||||
case <-ticker.C:
|
||||
lastConnectedAt = sql.NullTime{
|
||||
Time: database.Now(),
|
||||
Valid: true,
|
||||
}
|
||||
err = updateConnectionTimes()
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = ensureLatestBuild()
|
||||
if err != nil {
|
||||
// Disconnect agents that are no longer valid.
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
|
||||
var envs map[string]string
|
||||
if dbAgent.EnvironmentVariables.Valid {
|
||||
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
}
|
||||
agent := codersdk.WorkspaceAgent{
|
||||
ID: dbAgent.ID,
|
||||
CreatedAt: dbAgent.CreatedAt,
|
||||
UpdatedAt: dbAgent.UpdatedAt,
|
||||
ResourceID: dbAgent.ResourceID,
|
||||
InstanceID: dbAgent.AuthInstanceID.String,
|
||||
StartupScript: dbAgent.StartupScript.String,
|
||||
EnvironmentVariables: envs,
|
||||
}
|
||||
if dbAgent.FirstConnectedAt.Valid {
|
||||
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
|
||||
}
|
||||
if dbAgent.LastConnectedAt.Valid {
|
||||
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
|
||||
}
|
||||
if dbAgent.DisconnectedAt.Valid {
|
||||
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
|
||||
}
|
||||
switch {
|
||||
case !dbAgent.FirstConnectedAt.Valid:
|
||||
// If the agent never connected, it's waiting for the compute
|
||||
// to start up.
|
||||
agent.Status = codersdk.WorkspaceAgentWaiting
|
||||
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
|
||||
// If we've disconnected after our last connection, we know the
|
||||
// agent is no longer connected.
|
||||
agent.Status = codersdk.WorkspaceAgentDisconnected
|
||||
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
|
||||
// The connection updated it's timestamp within the update frequency.
|
||||
// We multiply by two to allow for some lag.
|
||||
agent.Status = codersdk.WorkspaceAgentConnected
|
||||
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
|
||||
// The connection died without updating the last connected.
|
||||
agent.Status = codersdk.WorkspaceAgentDisconnected
|
||||
}
|
||||
|
||||
return agent, nil
|
||||
render.JSON(rw, r, convertWorkspaceResource(workspaceResource, apiAgents))
|
||||
}
|
||||
|
@ -4,16 +4,10 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
)
|
||||
@ -33,10 +27,10 @@ func TestWorkspaceResource(t *testing.T) {
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agent: &proto.Agent{
|
||||
Agents: []*proto.Agent{{
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
@ -52,55 +46,3 @@ func TestWorkspaceResource(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceAgentListen(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "example",
|
||||
Type: "aws_instance",
|
||||
Agent: &proto.Agent{
|
||||
Id: uuid.NewString(),
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, codersdk.Me, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
daemonCloser.Close()
|
||||
|
||||
agentClient := codersdk.New(client.URL)
|
||||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentCloser.Close()
|
||||
})
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].ID, nil, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
})
|
||||
_, err = conn.Ping()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user