feat: Add Tailscale networking (#3505)

* fix: Add coder user to docker group on installation

This makes for a simpler setup, and reduces the likelihood
a user runs into a strange issue.

* Add wgnet

* Add ping

* Add listening

* Finish refactor to make this work

* Add interface for swapping

* Fix conncache with interface

* chore: update gvisor

* fix tailscale types

* linting

* more linting

* Add coordinator

* Add coordinator tests

* Fix coordination

* It compiles!

* Move all connection negotiation in-memory

* Migrate coordinator to use net.conn

* Add closed func

* Fix close listener func

* Make reconnecting PTY work

* Fix reconnecting PTY

* Update CI to Go 1.19

* Add CLI flags for DERP mapping

* Fix Tailnet test

* Rename ConnCoordinator to TailnetCoordinator

* Remove print statement from workspace agent test

* Refactor wsconncache to use tailnet

* Remove STUN from unit tests

* Add migrate back to dump

* chore: Upgrade to Go 1.19

This is required as part of #3505.

* Fix reconnecting PTY tests

* fix: update wireguard-go to fix devtunnel

* fix migration numbers

* linting

* Return early for status if endpoints are empty

* Update cli/server.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Update cli/server.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Fix frontend entites

* Fix agent bicopy

* Fix race condition for the last node

* Fix down migration

* Fix connection RBAC

* Fix migration numbers

* Fix forwarding TCP to a local port

* Implement ping for tailnet

* Rename to ForceHTTP

* Add external derpmapping

* Expose DERP region names to the API

* Add global option to enable Tailscale networking for web

* Mark DERP flags hidden while testing

* Update DERP map on reconnect

* Add close func to workspace agents

* Fix race condition in upstream dependency

* Fix feature columns race condition

Co-authored-by: Colin Adler <colin1adler@gmail.com>
This commit is contained in:
Kyle Carberry
2022-08-31 20:09:44 -05:00
committed by GitHub
parent 00da01fdf7
commit 9bd83e5ec7
56 changed files with 2498 additions and 1817 deletions

View File

@ -18,6 +18,10 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"cdr.dev/slog"
"github.com/coder/coder/buildinfo"
@ -33,6 +37,7 @@ import (
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
)
// Options are requires parameters for Coder to start.
@ -67,6 +72,10 @@ type Options struct {
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService FeaturesService
TailscaleEnable bool
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
}
// New constructs a Coder API handler.
@ -93,6 +102,9 @@ func New(options *Options) *API {
if options.PrometheusRegistry == nil {
options.PrometheusRegistry = prometheus.NewRegistry()
}
if options.TailnetCoordinator == nil {
options.TailnetCoordinator = tailnet.NewCoordinator()
}
if options.LicenseHandler == nil {
options.LicenseHandler = licenses()
}
@ -119,7 +131,12 @@ func New(options *Options) *API {
Logger: options.Logger,
},
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
if options.TailscaleEnable {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
} else {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
}
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
@ -148,6 +165,7 @@ func New(options *Options) *API {
// other applications might not as well.
r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Get("/derp", derphttp.Handler(api.derpServer).ServeHTTP)
r.Route("/api/v2", func(r chi.Router) {
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
@ -338,9 +356,8 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/wireguardlisten", api.workspaceAgentWireguardListener)
r.Post("/keys", api.postWorkspaceAgentKeys)
r.Get("/derp", api.derpMap)
r.Get("/coordinate", api.workspaceAgentCoordinate)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
@ -349,12 +366,13 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Post("/peer", api.postWorkspaceAgentWireguardPeer)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.userWorkspaceAgentTurn)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/derp", api.derpMap)
r.Get("/connection", api.workspaceAgentConnection)
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
})
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
@ -420,6 +438,8 @@ func New(options *Options) *API {
type API struct {
*Options
derpServer *derp.Server
Handler chi.Router
siteHandler http.Handler
websocketWaitMutex sync.Mutex

View File

@ -2,13 +2,21 @@ package coderd_test
import (
"context"
"net/netip"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
@ -38,3 +46,71 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
a.Test(ctx, assertRoute, skipRoutes)
}
func TestDERP(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpPort, err := strconv.Atoi(client.URL.Port())
require.NoError(t, err)
derpMap := &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "cdr",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
HostName: client.URL.Hostname(),
DERPPort: derpPort,
STUNPort: -1,
ForceHTTP: true,
}},
},
},
}
w1IP := tailnet.IP()
w1, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)},
Logger: logger.Named("w1"),
DERPMap: derpMap,
})
require.NoError(t, err)
w2, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
Logger: logger.Named("w2"),
DERPMap: derpMap,
})
require.NoError(t, err)
w1.SetNodeCallback(func(node *tailnet.Node) {
w2.UpdateNodes([]*tailnet.Node{node})
})
w2.SetNodeCallback(func(node *tailnet.Node) {
w1.UpdateNodes([]*tailnet.Node{node})
})
conn := make(chan struct{})
go func() {
listener, err := w1.Listen("tcp", ":35565")
assert.NoError(t, err)
defer listener.Close()
conn <- struct{}{}
nc, err := listener.Accept()
assert.NoError(t, err)
_ = nc.Close()
conn <- struct{}{}
}()
<-conn
nc, err := w2.DialContextTCP(context.Background(), netip.AddrPortFrom(w1IP, 35565))
require.NoError(t, err)
_ = nc.Close()
<-conn
w1.Close()
w2.Close()
}

View File

@ -167,6 +167,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// skipRoutes allows skipping routes from being checked.
skipRoutes := map[string]string{
"POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes",
"GET:/derp": "This requires a WebSocket upgrade!",
}
assertRoute := map[string]RouteCheck{
@ -193,12 +194,9 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/derp": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/wireguardlisten": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/keys": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/derp": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},
@ -271,6 +269,10 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/coordinate": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaces/": {
StatusCode: http.StatusOK,
AssertAction: rbac.ActionRead,

View File

@ -21,6 +21,7 @@ import (
"net/http/httptest"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"
@ -36,6 +37,7 @@ import (
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
@ -178,6 +180,9 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
serverURL, err := url.Parse(srv.URL)
require.NoError(t, err)
derpPort, err := strconv.Atoi(serverURL.Port())
require.NoError(t, err)
// match default with cli default
if options.SSHKeygenAlgorithm == "" {
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
@ -211,7 +216,25 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
AutoImportTemplates: options.AutoImportTemplates,
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: -1,
InsecureForTests: true,
ForceHTTP: true,
}},
},
},
},
AutoImportTemplates: options.AutoImportTemplates,
})
t.Cleanup(func() {
_ = coderAPI.Close()

View File

@ -1701,23 +1701,20 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
defer q.mutex.Unlock()
agent := database.WorkspaceAgent{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
WireguardNodeIPv6: arg.WireguardNodeIPv6,
WireguardNodePublicKey: arg.WireguardNodePublicKey,
WireguardDiscoPublicKey: arg.WireguardDiscoPublicKey,
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
}
q.provisionerJobAgents = append(q.provisionerJobAgents, agent)
@ -2029,24 +2026,6 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateWorkspaceAgentKeysByID(_ context.Context, arg database.UpdateWorkspaceAgentKeysByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for index, agent := range q.provisionerJobAgents {
if agent.ID != arg.ID {
continue
}
agent.WireguardNodePublicKey = arg.WireguardNodePublicKey
agent.WireguardDiscoPublicKey = arg.WireguardDiscoPublicKey
agent.UpdatedAt = arg.UpdatedAt
q.provisionerJobAgents[index] = agent
return nil
}
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateWorkspaceAgentVersionByID(_ context.Context, arg database.UpdateWorkspaceAgentVersionByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -1,74 +0,0 @@
package dbtypes
import (
"database/sql/driver"
"golang.org/x/xerrors"
"tailscale.com/types/key"
)
// NodePublic is a wrapper around a key.NodePublic which represents the
// Wireguard public key for an agent..
type NodePublic key.NodePublic
func (n NodePublic) String() string {
return key.NodePublic(n).String()
}
// This is necessary so NodePublic can be serialized in JSON loggers.
func (n NodePublic) MarshalJSON() ([]byte, error) {
j, err := key.NodePublic(n).MarshalText()
// surround in quotes to make it a JSON string
j = append([]byte{'"'}, append(j, '"')...)
return j, err
}
// Value is so NodePublic can be inserted into the database.
func (n NodePublic) Value() (driver.Value, error) {
return key.NodePublic(n).MarshalText()
}
// Scan is so NodePublic can be read from the database.
func (n *NodePublic) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return (*key.NodePublic)(n).UnmarshalText(v)
case string:
return (*key.NodePublic)(n).UnmarshalText([]byte(v))
default:
return xerrors.Errorf("unexpected type: %T", v)
}
}
// NodePublic is a wrapper around a key.NodePublic which represents the
// Tailscale disco key for an agent.
type DiscoPublic key.DiscoPublic
func (n DiscoPublic) String() string {
return key.DiscoPublic(n).String()
}
// This is necessary so DiscoPublic can be serialized in JSON loggers.
func (n DiscoPublic) MarshalJSON() ([]byte, error) {
j, err := key.DiscoPublic(n).MarshalText()
// surround in quotes to make it a JSON string
j = append([]byte{'"'}, append(j, '"')...)
return j, err
}
// Value is so DiscoPublic can be inserted into the database.
func (n DiscoPublic) Value() (driver.Value, error) {
return key.DiscoPublic(n).MarshalText()
}
// Scan is so DiscoPublic can be read from the database.
func (n *DiscoPublic) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return (*key.DiscoPublic)(n).UnmarshalText(v)
case string:
return (*key.DiscoPublic)(n).UnmarshalText([]byte(v))
default:
return xerrors.Errorf("unexpected type: %T", v)
}
}

View File

@ -309,9 +309,6 @@ CREATE TABLE workspace_agents (
instance_metadata jsonb,
resource_metadata jsonb,
directory character varying(4096) DEFAULT ''::character varying NOT NULL,
wireguard_node_ipv6 inet DEFAULT '::'::inet NOT NULL,
wireguard_node_public_key character varying(128) DEFAULT 'nodekey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL,
wireguard_disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL,
version text DEFAULT ''::text NOT NULL
);

View File

@ -0,0 +1,4 @@
ALTER TABLE workspace_agents
ADD COLUMN wireguard_node_ipv6 inet NOT NULL DEFAULT '::/128',
ADD COLUMN wireguard_node_public_key varchar(128) NOT NULL DEFAULT 'nodekey:0000000000000000000000000000000000000000000000000000000000000000',
ADD COLUMN wireguard_disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000';

View File

@ -0,0 +1,3 @@
ALTER TABLE workspace_agents DROP COLUMN wireguard_node_ipv6;
ALTER TABLE workspace_agents DROP COLUMN wireguard_node_public_key;
ALTER TABLE workspace_agents DROP COLUMN wireguard_disco_public_key;

View File

@ -10,7 +10,6 @@ import (
"fmt"
"time"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
)
@ -519,26 +518,23 @@ type Workspace struct {
}
type WorkspaceAgent struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
// Version tracks the version of the currently running workspace agent. Workspace agents register their version upon start.
Version string `db:"version" json:"version"`
}

View File

@ -143,7 +143,6 @@ type querier interface {
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (Workspace, error)
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error
UpdateWorkspaceAgentVersionByID(ctx context.Context, arg UpdateWorkspaceAgentVersionByIDParams) error
UpdateWorkspaceAutostart(ctx context.Context, arg UpdateWorkspaceAutostartParams) error
UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error

View File

@ -10,7 +10,6 @@ import (
"encoding/json"
"time"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@ -3173,7 +3172,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version
FROM
workspace_agents
WHERE
@ -3203,9 +3202,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
)
return i, err
@ -3213,7 +3209,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version
FROM
workspace_agents
WHERE
@ -3241,9 +3237,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
)
return i, err
@ -3251,7 +3244,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version
FROM
workspace_agents
WHERE
@ -3281,9 +3274,6 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
)
return i, err
@ -3291,7 +3281,7 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version
FROM
workspace_agents
WHERE
@ -3325,9 +3315,6 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
); err != nil {
return nil, err
@ -3344,7 +3331,7 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
}
const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version FROM workspace_agents WHERE created_at > $1
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version FROM workspace_agents WHERE created_at > $1
`
func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error) {
@ -3374,9 +3361,6 @@ func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, created
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
); err != nil {
return nil, err
@ -3408,33 +3392,27 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
resource_metadata
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key, version
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version
`
type InsertWorkspaceAgentParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
}
func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) {
@ -3453,9 +3431,6 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
arg.Directory,
arg.InstanceMetadata,
arg.ResourceMetadata,
arg.WireguardNodeIPv6,
arg.WireguardNodePublicKey,
arg.WireguardDiscoPublicKey,
)
var i WorkspaceAgent
err := row.Scan(
@ -3476,9 +3451,6 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
&i.Version,
)
return i, err
@ -3515,34 +3487,6 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
return err
}
const updateWorkspaceAgentKeysByID = `-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3,
updated_at = $4
WHERE
id = $1
`
type UpdateWorkspaceAgentKeysByIDParams struct {
ID uuid.UUID `db:"id" json:"id"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error {
_, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID,
arg.ID,
arg.WireguardNodePublicKey,
arg.WireguardDiscoPublicKey,
arg.UpdatedAt,
)
return err
}
const updateWorkspaceAgentVersionByID = `-- name: UpdateWorkspaceAgentVersionByID :exec
UPDATE
workspace_agents

View File

@ -53,13 +53,10 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
resource_metadata
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING *;
-- name: UpdateWorkspaceAgentConnectionByID :exec
UPDATE
@ -72,16 +69,6 @@ SET
WHERE
id = $1;
-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3,
updated_at = $4
WHERE
id = $1;
-- name: UpdateWorkspaceAgentVersionByID :exec
UPDATE
workspace_agents

View File

@ -16,12 +16,6 @@ packages:
# deleted after generation.
output_db_file_name: db_tmp.go
overrides:
- column: workspace_agents.wireguard_node_public_key
go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic
- column: workspace_agents.wireguard_disco_public_key
go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
@ -34,5 +28,5 @@ rename:
gitsshkey: GitSSHKey
rbac_roles: RBACRoles
ip_address: IPAddress
wireguard_node_ipv6: WireguardNodeIPv6
ip_addresses: IPAddresses
jwt: JWT

View File

@ -23,13 +23,11 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
@ -804,9 +802,6 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
String: prAgent.StartupScript,
Valid: prAgent.StartupScript != "",
},
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
WireguardNodePublicKey: dbtypes.NodePublic{},
WireguardDiscoPublicKey: dbtypes.DiscoPublic{},
})
if err != nil {
return xerrors.Errorf("insert agent: %w", err)

View File

@ -264,7 +264,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
}
}
apiAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading job agent.",

View File

@ -8,22 +8,21 @@ import (
"io"
"net"
"net/http"
"net/netip"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/tabbed/pqtype"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"inet.af/netaddr"
"nhooyr.io/websocket"
"tailscale.com/types/key"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
@ -31,10 +30,10 @@ import (
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/tailnet"
)
func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
@ -52,7 +51,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiAgent, err := convertWorkspaceAgent(workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -76,7 +75,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -127,7 +126,7 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -136,17 +135,8 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
return
}
ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.WireguardNodeIPv6.IPNet)
if !ok {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Workspace agent has an invalid ipv6 address.",
Detail: workspaceAgent.WireguardNodeIPv6.IPNet.String(),
})
return
}
httpapi.Write(rw, http.StatusOK, agent.Metadata{
WireguardAddresses: []netaddr.IPPrefix{ipp},
DERPMap: api.DERPMap,
EnvironmentVariables: apiAgent.EnvironmentVariables,
StartupScript: apiAgent.StartupScript,
Directory: apiAgent.Directory,
@ -155,7 +145,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -431,7 +421,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -498,141 +488,10 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(ptNetConn, wsNetConn)
}
func (*API) derpMap(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, peerwg.DerpMap)
}
type WorkspaceKeysRequest struct {
Public key.NodePublic `json:"public"`
Disco key.DiscoPublic `json:"disco"`
}
func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
workspaceAgent = httpmw.WorkspaceAgent(r)
keys WorkspaceKeysRequest
)
if !httpapi.Read(rw, r, &keys) {
return
}
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
ID: workspaceAgent.ID,
WireguardNodePublicKey: dbtypes.NodePublic(keys.Public),
WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco),
UpdatedAt: database.Now(),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting agent keys.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) postWorkspaceAgentWireguardPeer(rw http.ResponseWriter, r *http.Request) {
var (
req peerwg.Handshake
workspaceAgent = httpmw.WorkspaceAgentParam(r)
workspace = httpmw.WorkspaceParam(r)
)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
if !httpapi.Read(rw, r, &req) {
return
}
if req.Recipient != workspaceAgent.ID {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid recipient.",
})
return
}
raw, err := req.MarshalText()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling wireguard peer message.",
Detail: err.Error(),
})
return
}
err = api.Pubsub.Publish("wireguard_peers", raw)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error publishing wireguard peer message.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) workspaceAgentWireguardListener(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
subCancel, err := api.Pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
// Since we subscribe to all peer broadcasts, we do a light check to
// make sure we're the intended recipient without fully decoding the
// message.
hint, err := peerwg.HandshakeRecipientHint(agentIDBytes, message)
if err != nil {
api.Logger.Error(ctx, "invalid wireguard peer message", slog.Error(err))
return
}
// We aren't the intended recipient.
if !hint {
return
}
_ = conn.Write(ctx, websocket.MessageBinary, message)
})
if err != nil {
api.Logger.Error(ctx, "pubsub listen", slog.Error(err))
return
}
defer subCancel()
// end span so we don't get long lived trace data
tracing.EndHTTPSpan(r, 200)
// Wait for the connection to close or the client to send a message.
//nolint:dogsled
_, _, _ = conn.Reader(ctx)
}
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if it's use is safe or r.Hijack() has
// not been performed.
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
client, server := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
@ -691,12 +550,110 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
<-peerConn.Closed()
cancelFunc()
}()
return &agent.Conn{
return &agent.WebRTCConn{
Negotiator: peerClient,
Conn: peerConn,
}, nil
}
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
clientConn, serverConn := net.Pipe()
go func() {
<-r.Context().Done()
_ = clientConn.Close()
_ = serverConn.Close()
}()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: api.DERPMap,
Logger: api.Logger.Named("tailnet").Leveled(slog.LevelDebug),
})
if err != nil {
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
sendNodes, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNodes)
go func() {
err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
_ = conn.Close()
}
}()
return &agent.TailnetConn{
Conn: conn,
}, nil
}
func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request) {
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionRead, workspace) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(rw, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
DERPMap: api.DERPMap,
})
}
func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
// After accept a PubSub starts listening for new connection node updates
// which are written to the WebSocket.
func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.Request) {
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeClient(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}
func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
apps := make([]codersdk.WorkspaceApp, 0)
for _, dbApp := range dbApps {
@ -710,28 +667,14 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
return apps
}
func inetToNetaddr(inet pqtype.Inet) netaddr.IPPrefix {
if !inet.Valid {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
ipp, ok := netaddr.FromStdIPNet(&inet.IPNet)
if !ok {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
return ipp
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err)
}
}
workspaceAgent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
@ -742,13 +685,29 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
Architecture: dbAgent.Architecture,
OperatingSystem: dbAgent.OperatingSystem,
StartupScript: dbAgent.StartupScript.String,
Version: dbAgent.Version,
EnvironmentVariables: envs,
Directory: dbAgent.Directory,
Apps: apps,
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey),
DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey),
Version: dbAgent.Version,
}
node := coordinator.Node(dbAgent.ID)
if node != nil {
workspaceAgent.DERPLatency = map[string]codersdk.DERPRegion{}
for rawRegion, latency := range node.DERPLatency {
regionParts := strings.SplitN(rawRegion, "-", 2)
regionID, err := strconv.Atoi(regionParts[0])
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("convert derp region id %q: %w", rawRegion, err)
}
region, found := derpMap.Regions[regionID]
if !found {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("region %d not found in derpmap", regionID)
}
workspaceAgent.DERPLatency[region.RegionName] = codersdk.DERPRegion{
Preferred: node.PreferredDERP == regionID,
LatencyMilliseconds: latency * 1000,
}
}
}
if dbAgent.FirstConnectedAt.Valid {

View File

@ -109,8 +109,11 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
@ -199,7 +202,7 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
_, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
_, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
})
@ -240,8 +243,11 @@ func TestWorkspaceAgentTURN(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil),
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
@ -265,6 +271,65 @@ func TestWorkspaceAgentTURN(t *testing.T) {
require.NoError(t, err)
}
func TestWorkspaceAgentTailnet(t *testing.T) {
t.Parallel()
client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil)
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer agentCloser.Close()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID)
require.NoError(t, err)
defer conn.Close()
sshClient, err := conn.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
output, err := session.CombinedOutput("echo test")
require.NoError(t, err)
_ = session.Close()
_ = sshClient.Close()
_ = conn.Close()
require.Equal(t, "test", strings.TrimSpace(string(output)))
}
func TestWorkspaceAgentPTY(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
@ -305,8 +370,11 @@ func TestWorkspaceAgentPTY(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil),
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()

View File

@ -81,8 +81,11 @@ func TestWorkspaceAppsProxyPath(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil),
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {
_ = agentCloser.Close()

View File

@ -70,7 +70,7 @@ func (api *API) workspaceResource(rw http.ResponseWriter, r *http.Request) {
}
}
convertedAgent, err := convertWorkspaceAgent(agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
convertedAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",

View File

@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
*agent.Conn
agent.Conn
locks atomic.Uint64
timeoutMutex sync.Mutex

View File

@ -7,13 +7,13 @@ import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
@ -23,10 +23,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
)
func TestMain(m *testing.M) {
@ -37,7 +35,7 @@ func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, 0)
defer func() {
@ -52,7 +50,7 @@ func TestCache(t *testing.T) {
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
called.Add(1)
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
@ -71,7 +69,7 @@ func TestCache(t *testing.T) {
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@ -104,7 +102,7 @@ func TestCache(t *testing.T) {
}()
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@ -141,37 +139,48 @@ func TestCache(t *testing.T) {
})
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, func(ctx context.Context) ([]webrtc.ICEServer, *peer.ConnOptions, error) {
return nil, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
}, nil
})
return metadata, listener, err
}, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()
agentID := uuid.New()
closer := agent.New(agent.Options{
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
return metadata, nil
},
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = serverConn.Close()
_ = clientConn.Close()
})
go coordinator.ServeAgent(serverConn, agentID)
return clientConn, nil
},
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelInfo),
ReconnectingPTYTimeout: ptyTimeout,
})
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
_ = closer.Close()
})
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := api.NegotiateConnection(context.Background())
assert.NoError(t, err)
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: metadata.DERPMap,
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
})
return &agent.Conn{
Negotiator: api,
Conn: conn,
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
return &agent.TailnetConn{
Conn: conn,
}
}