feat: peer wireguard (#2445)

This commit is contained in:
Colin Adler
2022-06-24 10:25:01 -05:00
committed by GitHub
parent d21ab2115d
commit 05b67ab1cf
34 changed files with 1935 additions and 236 deletions

View File

@ -6,7 +6,6 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
)

View File

@ -307,6 +307,9 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/wireguardlisten", api.workspaceAgentWireguardListener)
r.Post("/keys", api.postWorkspaceAgentKeys)
r.Get("/derp", api.derpMap)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
@ -315,10 +318,12 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Post("/peer", api.postWorkspaceAgentWireguardPeer)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/derp", api.derpMap)
})
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {

View File

@ -154,8 +154,12 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/derp": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/wireguardlisten": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/keys": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/derp": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(admin.OrganizationID)},

View File

@ -25,17 +25,12 @@ import (
"testing"
"time"
"github.com/spf13/afero"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/util/ptr"
"cloud.google.com/go/compute/metadata"
"github.com/fullsailor/pkcs7"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/api/idtoken"
@ -50,7 +45,10 @@ import (
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"

View File

@ -1168,7 +1168,7 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
if agent.AuthToken.String() == authToken.String() {
if agent.AuthToken == authToken {
return agent, nil
}
}
@ -1182,7 +1182,7 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
if agent.ID.String() == id.String() {
if agent.ID == id {
return agent, nil
}
}
@ -1210,7 +1210,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.provisionerJobAgents {
for _, resourceID := range resourceIDs {
if agent.ResourceID.String() != resourceID.String() {
if agent.ResourceID != resourceID {
continue
}
workspaceAgents = append(workspaceAgents, agent)
@ -1269,7 +1269,7 @@ func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (da
defer q.mutex.RUnlock()
for _, provisionerJob := range q.provisionerJobs {
if provisionerJob.ID.String() != id.String() {
if provisionerJob.ID != id {
continue
}
return provisionerJob, nil
@ -1604,23 +1604,26 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
q.mutex.Lock()
defer q.mutex.Unlock()
//nolint:gosimple
agent := database.WorkspaceAgent{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
WireguardNodeIPv6: arg.WireguardNodeIPv6,
WireguardNodePublicKey: arg.WireguardNodePublicKey,
WireguardDiscoPublicKey: arg.WireguardDiscoPublicKey,
}
q.provisionerJobAgents = append(q.provisionerJobAgents, agent)
return agent, nil
}
@ -1874,7 +1877,7 @@ func (q *fakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context,
continue
}
templateVersion.Readme = arg.Readme
templateVersion.UpdatedAt = time.Now()
templateVersion.UpdatedAt = database.Now()
q.templateVersions[index] = templateVersion
return nil
}
@ -1914,6 +1917,24 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateWorkspaceAgentKeysByID(_ context.Context, arg database.UpdateWorkspaceAgentKeysByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for index, agent := range q.provisionerJobAgents {
if agent.ID != arg.ID {
continue
}
agent.WireguardNodePublicKey = arg.WireguardNodePublicKey
agent.WireguardDiscoPublicKey = arg.WireguardDiscoPublicKey
agent.UpdatedAt = database.Now()
q.provisionerJobAgents[index] = agent
return nil
}
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -292,7 +292,10 @@ CREATE TABLE workspace_agents (
startup_script character varying(65534),
instance_metadata jsonb,
resource_metadata jsonb,
directory character varying(4096) DEFAULT ''::character varying NOT NULL
directory character varying(4096) DEFAULT ''::character varying NOT NULL,
wireguard_node_ipv6 inet DEFAULT '::'::inet NOT NULL,
wireguard_node_public_key character varying(128) DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL,
wireguard_disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL
);
CREATE TABLE workspace_apps (

View File

@ -0,0 +1,4 @@
ALTER TABLE workspace_agents
DROP COLUMN wireguard_node_ipv6,
DROP COLUMN wireguard_node_public_key,
DROP COLUMN wireguard_disco_public_key;

View File

@ -0,0 +1,4 @@
ALTER TABLE workspace_agents
ADD COLUMN wireguard_node_ipv6 inet NOT NULL DEFAULT '::/128',
ADD COLUMN wireguard_node_public_key varchar(128) NOT NULL DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000',
ADD COLUMN wireguard_disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000';

View File

@ -503,23 +503,26 @@ type Workspace struct {
}
type WorkspaceAgent struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
type WorkspaceApp struct {

View File

@ -28,7 +28,7 @@ type pgPubsub struct {
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[string]Listener
listeners map[string]map[uuid.UUID]Listener
}
// Subscribe calls the listener when an event matching the name is received.
@ -45,20 +45,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err)
}
var listeners map[string]Listener
var eventListeners map[uuid.UUID]Listener
var ok bool
if listeners, ok = p.listeners[event]; !ok {
listeners = map[string]Listener{}
p.listeners[event] = listeners
if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = map[uuid.UUID]Listener{}
p.listeners[event] = eventListeners
}
var id string
var id uuid.UUID
for {
id = uuid.New().String()
if _, ok = listeners[id]; !ok {
id = uuid.New()
if _, ok = eventListeners[id]; !ok {
break
}
}
listeners[id] = listener
eventListeners[id] = listener
return func() {
p.mut.Lock()
defer p.mut.Unlock()
@ -77,7 +79,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
//nolint:gosec
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
return xerrors.Errorf("exec: %w", err)
return xerrors.Errorf("exec pg_notify: %w", err)
}
return nil
}
@ -128,7 +130,7 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
@ -150,7 +152,7 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
pgPubsub := &pgPubsub{
db: database,
pgListener: listener,
listeners: make(map[string]map[string]Listener),
listeners: make(map[string]map[uuid.UUID]Listener),
}
go pgPubsub.listen(ctx)

View File

@ -127,6 +127,7 @@ type querier interface {
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error
UpdateWorkspaceAutostart(ctx context.Context, arg UpdateWorkspaceAutostartParams) error
UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error
UpdateWorkspaceDeletedByID(ctx context.Context, arg UpdateWorkspaceDeletedByIDParams) error

View File

@ -2850,7 +2850,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2880,13 +2880,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2914,13 +2917,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2950,13 +2956,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2990,6 +2999,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
); err != nil {
return nil, err
}
@ -3005,7 +3017,7 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
}
const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory FROM workspace_agents WHERE created_at > $1
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE created_at > $1
`
func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error) {
@ -3035,6 +3047,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, created
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
); err != nil {
return nil, err
}
@ -3065,27 +3080,33 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
`
type InsertWorkspaceAgentParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) {
@ -3104,6 +3125,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
arg.Directory,
arg.InstanceMetadata,
arg.ResourceMetadata,
arg.WireguardNodeIPv6,
arg.WireguardNodePublicKey,
arg.WireguardDiscoPublicKey,
)
var i WorkspaceAgent
err := row.Scan(
@ -3124,6 +3148,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
@ -3132,6 +3159,7 @@ const updateWorkspaceAgentConnectionByID = `-- name: UpdateWorkspaceAgentConnect
UPDATE
workspace_agents
SET
updated_at = now(),
first_connected_at = $2,
last_connected_at = $3,
disconnected_at = $4
@ -3156,6 +3184,28 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
return err
}
const updateWorkspaceAgentKeysByID = `-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3
WHERE
id = $1
`
type UpdateWorkspaceAgentKeysByIDParams struct {
ID uuid.UUID `db:"id" json:"id"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error {
_, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID, arg.ID, arg.WireguardNodePublicKey, arg.WireguardDiscoPublicKey)
return err
}
const getWorkspaceAppByAgentIDAndName = `-- name: GetWorkspaceAppByAgentIDAndName :one
SELECT id, created_at, agent_id, name, icon, command, url, relative_path FROM workspace_apps WHERE agent_id = $1 AND name = $2
`

View File

@ -53,17 +53,31 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING *;
-- name: UpdateWorkspaceAgentConnectionByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
first_connected_at = $2,
last_connected_at = $3,
disconnected_at = $4
WHERE
id = $1;
-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3
WHERE
id = $1;

View File

@ -15,9 +15,13 @@ packages:
# to add support for transactions. This file is
# deleted after generation.
output_db_file_name: db_tmp.go
overrides:
- db_type: citext
go_type: string
- column: workspaces.wireguard_public_key
go_type: tailscale.com/types/key.MachinePublic
- column: workspaces.disco_public_key
go_type: tailscale.com/types/key.DiscoPublic
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
@ -30,3 +34,4 @@ rename:
gitsshkey: GitSSHKey
rbac_roles: RBACRoles
ip_address: IPAddress
wireguard_node_ipv6: WireguardNodeIPv6

View File

@ -37,21 +37,20 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
}
token, err := uuid.Parse(cookie.Value)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("Parse token %q: %s.", cookie.Value, err),
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "Agent token is invalid.",
})
return
}
agent, err := db.GetWorkspaceAgentByAuthToken(r.Context(), token)
if errors.Is(err, sql.ErrNoRows) {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "Agent token is invalid.",
})
return
}
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching workspace agent.",
Detail: err.Error(),

View File

@ -31,6 +31,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
if !parsed {
return
}
agent, err := db.GetWorkspaceAgentByID(r.Context(), agentUUID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
@ -45,6 +46,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
})
return
}
resource, err := db.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{

View File

@ -19,6 +19,7 @@ import (
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/types/key"
"cdr.dev/slog"
@ -27,6 +28,7 @@ import (
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
@ -714,17 +716,17 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
snapshot.WorkspaceResources = append(snapshot.WorkspaceResources, telemetry.ConvertWorkspaceResource(resource))
for _, agent := range protoResource.Agents {
for _, prAgent := range protoResource.Agents {
var instanceID sql.NullString
if agent.GetInstanceId() != "" {
if prAgent.GetInstanceId() != "" {
instanceID = sql.NullString{
String: agent.GetInstanceId(),
String: prAgent.GetInstanceId(),
Valid: true,
}
}
var env pqtype.NullRawMessage
if agent.Env != nil {
data, err := json.Marshal(agent.Env)
if prAgent.Env != nil {
data, err := json.Marshal(prAgent.Env)
if err != nil {
return xerrors.Errorf("marshal env: %w", err)
}
@ -734,36 +736,40 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
}
authToken := uuid.New()
if agent.GetToken() != "" {
authToken, err = uuid.Parse(agent.GetToken())
if prAgent.GetToken() != "" {
authToken, err = uuid.Parse(prAgent.GetToken())
if err != nil {
return xerrors.Errorf("invalid auth token format; must be uuid: %w", err)
}
}
agentID := uuid.New()
dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
ID: uuid.New(),
ID: agentID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
ResourceID: resource.ID,
Name: agent.Name,
Name: prAgent.Name,
AuthToken: authToken,
AuthInstanceID: instanceID,
Architecture: agent.Architecture,
Architecture: prAgent.Architecture,
EnvironmentVariables: env,
Directory: agent.Directory,
OperatingSystem: agent.OperatingSystem,
Directory: prAgent.Directory,
OperatingSystem: prAgent.OperatingSystem,
StartupScript: sql.NullString{
String: agent.StartupScript,
Valid: agent.StartupScript != "",
String: prAgent.StartupScript,
Valid: prAgent.StartupScript != "",
},
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
WireguardNodePublicKey: key.NodePublic{}.String(),
WireguardDiscoPublicKey: key.DiscoPublic{}.String(),
})
if err != nil {
return xerrors.Errorf("insert agent: %w", err)
}
snapshot.WorkspaceAgents = append(snapshot.WorkspaceAgents, telemetry.ConvertWorkspaceAgent(dbAgent))
for _, app := range agent.Apps {
for _, app := range prAgent.Apps {
dbApp, err := db.InsertWorkspaceApp(ctx, database.InsertWorkspaceAppParams{
ID: uuid.New(),
CreatedAt: database.Now(),

View File

@ -13,8 +13,11 @@ import (
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/tabbed/pqtype"
"golang.org/x/xerrors"
"inet.af/netaddr"
"nhooyr.io/websocket"
"tailscale.com/types/key"
"cdr.dev/slog"
"github.com/coder/coder/agent"
@ -25,6 +28,7 @@ import (
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
@ -156,7 +160,18 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
})
return
}
ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.WireguardNodeIPv6.IPNet)
if !ok {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Workspace agent has an invalid ipv6 address.",
Detail: workspaceAgent.WireguardNodeIPv6.IPNet.String(),
})
return
}
httpapi.Write(rw, http.StatusOK, agent.Metadata{
WireguardAddresses: []netaddr.IPPrefix{ipp},
OwnerEmail: owner.Email,
OwnerUsername: owner.Username,
EnvironmentVariables: apiAgent.EnvironmentVariables,
@ -452,6 +467,133 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(ptNetConn, wsNetConn)
}
func (*API) derpMap(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, peerwg.DerpMap)
}
type WorkspaceKeysRequest struct {
Public key.NodePublic `json:"public"`
Disco key.DiscoPublic `json:"disco"`
}
func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
workspaceAgent = httpmw.WorkspaceAgent(r)
keys WorkspaceKeysRequest
)
if !httpapi.Read(rw, r, &keys) {
return
}
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
ID: workspaceAgent.ID,
WireguardNodePublicKey: keys.Public.String(),
WireguardDiscoPublicKey: keys.Disco.String(),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error setting agent keys.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) postWorkspaceAgentWireguardPeer(rw http.ResponseWriter, r *http.Request) {
var (
req peerwg.Handshake
workspaceAgent = httpmw.WorkspaceAgentParam(r)
workspace = httpmw.WorkspaceParam(r)
)
if !api.Authorize(r, rbac.ActionUpdate, workspace) {
httpapi.ResourceNotFound(rw)
return
}
if !httpapi.Read(rw, r, &req) {
return
}
if req.Recipient != workspaceAgent.ID {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Invalid recipient.",
})
return
}
raw, err := req.MarshalText()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error marshaling wireguard peer message.",
Detail: err.Error(),
})
return
}
err = api.Pubsub.Publish("wireguard_peers", raw)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error publishing wireguard peer message.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) workspaceAgentWireguardListener(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
subCancel, err := api.Pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
// Since we subscribe to all peer broadcasts, we do a light check to
// make sure we're the intended recipient without fully decoding the
// message.
hint, err := peerwg.HandshakeRecipientHint(agentIDBytes, message)
if err != nil {
api.Logger.Error(ctx, "invalid wireguard peer message", slog.Error(err))
return
}
// We aren't the intended recipient.
if !hint {
return
}
_ = conn.Write(ctx, websocket.MessageBinary, message)
})
if err != nil {
api.Logger.Error(ctx, "pubsub listen", slog.Error(err))
return
}
defer subCancel()
// Wait for the connection to close or the client to send a message.
//nolint:dogsled
_, _, _ = conn.Reader(ctx)
}
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if it's use is safe or r.Hijack() has
// not been performed.
@ -533,6 +675,19 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
return apps
}
func inetToNetaddr(inet pqtype.Inet) netaddr.IPPrefix {
if !inet.Valid {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
ipp, ok := netaddr.FromStdIPNet(&inet.IPNet)
if !ok {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
return ipp
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
@ -541,6 +696,7 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
}
}
workspaceAgent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
@ -554,7 +710,18 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
EnvironmentVariables: envs,
Directory: dbAgent.Directory,
Apps: apps,
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
}
err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err)
}
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err)
}
if dbAgent.FirstConnectedAt.Valid {
workspaceAgent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
}