chore: replace wsconncache with a single tailnet (#8176)

This commit is contained in:
Colin Adler
2023-07-12 17:37:31 -05:00
committed by GitHub
parent 0a37dd20d6
commit c47b78c44b
36 changed files with 2004 additions and 763 deletions

5
coderd/apidoc/docs.go generated
View File

@ -5961,6 +5961,9 @@ const docTemplate = `{
"agentsdk.Manifest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"apps": {
"type": "array",
"items": {
@ -7617,6 +7620,7 @@ const docTemplate = `{
"workspace_actions",
"tailnet_ha_coordinator",
"convert-to-oidc",
"single_tailnet",
"workspace_build_logs_ui"
],
"x-enum-varnames": [
@ -7624,6 +7628,7 @@ const docTemplate = `{
"ExperimentWorkspaceActions",
"ExperimentTailnetHACoordinator",
"ExperimentConvertToOIDC",
"ExperimentSingleTailnet",
"ExperimentWorkspaceBuildLogsUI"
]
},

View File

@ -5251,6 +5251,9 @@
"agentsdk.Manifest": {
"type": "object",
"properties": {
"agent_id": {
"type": "string"
},
"apps": {
"type": "array",
"items": {
@ -6818,6 +6821,7 @@
"workspace_actions",
"tailnet_ha_coordinator",
"convert-to-oidc",
"single_tailnet",
"workspace_build_logs_ui"
],
"x-enum-varnames": [
@ -6825,6 +6829,7 @@
"ExperimentWorkspaceActions",
"ExperimentTailnetHACoordinator",
"ExperimentConvertToOIDC",
"ExperimentSingleTailnet",
"ExperimentWorkspaceBuildLogsUI"
]
},

View File

@ -364,8 +364,23 @@ func New(options *Options) *API {
}
api.Auditor.Store(&options.Auditor)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) {
api.agentProvider, err = NewServerTailnet(api.ctx,
options.Logger,
options.DERPServer,
options.DERPMap,
&api.TailnetCoordinator,
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
)
if err != nil {
panic("failed to setup server tailnet: " + err.Error())
}
} else {
api.agentProvider = &wsconncache.AgentProvider{
Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
}
}
api.workspaceAppServer = &workspaceapps.Server{
Logger: options.Logger.Named("workspaceapps"),
@ -377,7 +392,7 @@ func New(options *Options) *API {
RealIPConfig: options.RealIPConfig,
SignedTokenProvider: api.WorkspaceAppsProvider,
WorkspaceConnCache: api.workspaceAgentCache,
AgentProvider: api.agentProvider,
AppSecurityKey: options.AppSecurityKey,
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
@ -921,10 +936,10 @@ type API struct {
derpCloseFunc func()
metricsCache *metricscache.Cache
workspaceAgentCache *wsconncache.Cache
updateChecker *updatecheck.Checker
WorkspaceAppsProvider workspaceapps.SignedTokenProvider
workspaceAppServer *workspaceapps.Server
agentProvider workspaceapps.AgentProvider
// Experiments contains the list of experiments currently enabled.
// This is used to gate features that are not yet ready for production.
@ -951,7 +966,8 @@ func (api *API) Close() error {
if coordinator != nil {
_ = (*coordinator).Close()
}
return api.workspaceAgentCache.Close()
_ = api.agentProvider.Close()
return nil
}
func compressHandler(h http.Handler) http.Handler {

View File

@ -109,6 +109,7 @@ type Options struct {
GitAuthConfigs []*gitauth.Config
TrialGenerator func(context.Context, string) error
TemplateScheduleStore schedule.TemplateScheduleStore
Coordinator tailnet.Coordinator
HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report
HealthcheckTimeout time.Duration

View File

@ -302,7 +302,7 @@ func TestAgents(t *testing.T) {
coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug))
coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{}
coordinatorPtr.Store(&coordinator)
derpMap := tailnettest.RunDERPAndSTUN(t)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests
registry := prometheus.NewRegistry()

339
coderd/tailnet.go Normal file
View File

@ -0,0 +1,339 @@
package coderd
import (
"bufio"
"context"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/derp"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
)
var tailnetTransport *http.Transport
func init() {
var valid bool
tailnetTransport, valid = http.DefaultTransport.(*http.Transport)
if !valid {
panic("dev error: default transport is the wrong type")
}
}
// NewServerTailnet creates a new tailnet intended for use by coderd. It
// automatically falls back to wsconncache if a legacy agent is encountered.
func NewServerTailnet(
ctx context.Context,
logger slog.Logger,
derpServer *derp.Server,
derpMap *tailcfg.DERPMap,
coord *atomic.Pointer[tailnet.Coordinator],
cache *wsconncache.Cache,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap,
Logger: logger,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
serverCtx, cancel := context.WithCancel(ctx)
tn := &ServerTailnet{
ctx: serverCtx,
cancel: cancel,
logger: logger,
conn: conn,
coord: coord,
cache: cache,
agentNodes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
}
tn.transport.DialContext = tn.dialContext
tn.transport.MaxIdleConnsPerHost = 10
tn.transport.MaxIdleConns = 0
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
tn.agentConn.Store(&agentConn)
err = tn.getAgentConn().UpdateSelf(conn.Node())
if err != nil {
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
}
conn.SetNodeCallback(func(node *tailnet.Node) {
err := tn.getAgentConn().UpdateSelf(node)
if err != nil {
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
}
})
// This is set to allow local DERP traffic to be proxied through memory
// instead of needing to hit the external access URL. Don't use the ctx
// given in this callback, it's only valid while connecting.
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
derpServer.Accept(ctx, right, brw, "internal")
}()
return left
})
go tn.watchAgentUpdates()
go tn.expireOldAgents()
return tn, nil
}
func (s *ServerTailnet) expireOldAgents() {
const (
tick = 5 * time.Minute
cutoff = 30 * time.Minute
)
ticker := time.NewTicker(tick)
defer ticker.Stop()
for {
select {
case <-s.ctx.Done():
return
case <-ticker.C:
}
s.nodesMu.Lock()
agentConn := s.getAgentConn()
for agentID, lastConnection := range s.agentNodes {
// If no one has connected since the cutoff and there are no active
// connections, remove the agent.
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
_ = agentConn
// err := agentConn.UnsubscribeAgent(agentID)
// if err != nil {
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
// }
// delete(s.agentNodes, agentID)
// TODO(coadler): actually remove from the netmap, then reenable
// the above
}
}
s.nodesMu.Unlock()
}
}
func (s *ServerTailnet) watchAgentUpdates() {
for {
conn := s.getAgentConn()
nodes, ok := conn.NextUpdate(s.ctx)
if !ok {
if conn.IsClosed() && s.ctx.Err() == nil {
s.reinitCoordinator()
continue
}
return
}
err := s.conn.UpdateNodes(nodes, false)
if err != nil {
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
return
}
}
}
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
return *s.agentConn.Load()
}
func (s *ServerTailnet) reinitCoordinator() {
s.nodesMu.Lock()
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
s.agentConn.Store(&agentConn)
// Resubscribe to all of the agents we're tracking.
for agentID := range s.agentNodes {
err := agentConn.SubscribeAgent(agentID)
if err != nil {
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
}
}
s.nodesMu.Unlock()
}
type ServerTailnet struct {
ctx context.Context
cancel func()
logger slog.Logger
conn *tailnet.Conn
coord *atomic.Pointer[tailnet.Coordinator]
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentNodes is a map of agent tailnetNodes the server wants to keep a
// connection to. It contains the last time the agent was connected to.
agentNodes map[uuid.UUID]time.Time
// agentTockets holds a map of all open connections to an agent.
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
transport *http.Transport
}
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: dashboardURL.String(),
})
}
proxy.Director = s.director(agentID, proxy.Director)
proxy.Transport = s.transport
return proxy, func() {}, nil
}
type agentIDKey struct{}
// director makes sure agentIDKey is set on the context in the reverse proxy.
// This allows the transport to correctly identify which agent to dial to.
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
return func(req *http.Request) {
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
*req = *req.WithContext(ctx)
prev(req)
}
}
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
if !ok {
return nil, xerrors.Errorf("no agent id attached")
}
return s.DialAgentNetConn(ctx, agentID, network, addr)
}
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
s.nodesMu.Lock()
defer s.nodesMu.Unlock()
_, ok := s.agentNodes[agentID]
// If we don't have the node, subscribe.
if !ok {
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
err := s.getAgentConn().SubscribeAgent(agentID)
if err != nil {
return xerrors.Errorf("subscribe agent: %w", err)
}
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
}
s.agentNodes[agentID] = time.Now()
return nil
}
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
var (
conn *codersdk.WorkspaceAgentConn
ret = func() {}
)
if s.getAgentConn().AgentIsLegacy(agentID) {
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
cconn, release, err := s.cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
}
conn = cconn.WorkspaceAgentConn
ret = release
} else {
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
}
// Since we now have an open conn, be careful to close it if we error
// without returning it to the user.
reachable := conn.AwaitReachable(ctx)
if !reachable {
ret()
conn.Close()
return nil, nil, xerrors.New("agent is unreachable")
}
return conn, ret, nil
}
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
conn, release, err := s.AgentConn(ctx, agentID)
if err != nil {
return nil, xerrors.Errorf("acquire agent conn: %w", err)
}
// Since we now have an open conn, be careful to close it if we error
// without returning it to the user.
nc, err := conn.DialContext(ctx, network, addr)
if err != nil {
release()
conn.Close()
return nil, xerrors.Errorf("dial context: %w", err)
}
return &netConnCloser{Conn: nc, close: func() {
release()
conn.Close()
}}, err
}
type netConnCloser struct {
net.Conn
close func()
}
func (c *netConnCloser) Close() error {
c.close()
return c.Conn.Close()
}
func (s *ServerTailnet) Close() error {
s.cancel()
_ = s.cache.Close()
_ = s.conn.Close()
s.transport.CloseIdleConnections()
return nil
}

207
coderd/tailnet_test.go Normal file
View File

@ -0,0 +1,207 @@
package coderd_test
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/agent/agenttest"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/codersdk/agentsdk"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/tailnet/tailnettest"
"github.com/coder/coder/testutil"
)
func TestServerTailnet_AgentConn_OK(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Connect through the ServerTailnet
agentID, _, serverTailnet := setupAgent(t, nil)
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
require.NoError(t, err)
defer release()
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_AgentConn_Legacy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
require.NoError(t, err)
defer release()
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_ReverseProxy_OK(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, nil)
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
DERPMap: derpMap,
}
var coordPtr atomic.Pointer[tailnet.Coordinator]
coord := tailnet.NewCoordinator(logger)
coordPtr.Store(&coord)
t.Cleanup(func() {
_ = coord.Close()
})
c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
options := agent.Options{
Client: c,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
Addresses: agentAddresses,
}
ag := agent.New(options)
t.Cleanup(func() {
_ = ag.Close()
})
// Wait for the agent to connect.
require.Eventually(t, func() bool {
return coord.Node(manifest.AgentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
Logger: logger.Named("client"),
})
require.NoError(t, err)
clientConn, serverConn := net.Pipe()
serveClientDone := make(chan struct{})
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
_ = conn.Close()
<-serveClientDone
})
go func() {
defer close(serveClientDone)
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
}()
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error { return codersdk.ErrSkipClose },
}), nil
}, 0)
serverTailnet, err := coderd.NewServerTailnet(
context.Background(),
logger,
derpServer,
manifest.DERPMap,
&coordPtr,
cache,
)
require.NoError(t, err)
t.Cleanup(func() {
_ = serverTailnet.Close()
})
return manifest.AgentID, ag, serverTailnet
}

View File

@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{
AgentID: apiAgent.ID,
Apps: convertApps(dbApps),
DERPMap: api.DERPMap,
GitAuthConfigs: len(api.GitAuthConfigs),
@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return
}
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.",
@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
}
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
// Deprecated: use api.tailnet.AgentConn instead.
// See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac
return nil
})
conn.SetNodeCallback(sendNodes)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error {
cancel()
_ = clientConn.Close()
_ = serverConn.Close()
return nil
},
}
})
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {

View File

@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques
return resp, err
}
func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
t.Helper()
var resp *http.Response
var err error
require.Eventually(t, func() bool {

View File

@ -23,7 +23,6 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/util/slice"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
)
@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{
"Sec-Websocket-Version": "Sec-WebSocket-Version",
}
type AgentProvider interface {
// ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests
// to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return an error.
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error)
// AgentConn returns a new connection to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return a release
// func.
AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error)
Close() error
}
// Server serves workspace apps endpoints, including:
// - Path-based apps
// - Subdomain app middleware
@ -83,7 +98,6 @@ type Server struct {
RealIPConfig *httpmw.RealIPConfig
SignedTokenProvider SignedTokenProvider
WorkspaceConnCache *wsconncache.Cache
AppSecurityKey SecurityKey
// DisablePathApps disables path-based apps. This is a security feature as path
@ -95,6 +109,8 @@ type Server struct {
DisablePathApps bool
SecureAuthCookie bool
AgentProvider AgentProvider
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
}
@ -106,8 +122,8 @@ func (s *Server) Close() error {
s.websocketWaitGroup.Wait()
s.websocketWaitMutex.Unlock()
// The caller must close the SignedTokenProvider (if necessary) and the
// wsconncache.
// The caller must close the SignedTokenProvider and the AgentProvider (if
// necessary).
return nil
}
@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
r.URL.Path = path
appURL.RawQuery = ""
proxy := httputil.NewSingleHostReverseProxy(appURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: s.DashboardURL.String(),
})
}
conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
if err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,
@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
proxy.ModifyResponse = func(r *http.Response) error {
r.Header.Del(httpmw.AccessControlAllowOriginHeader)
@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
go httpapi.Heartbeat(ctx, conn)
agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
if err != nil {
log.Debug(ctx, "dial workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
}
defer release()
defer agentConn.Close()
log.Debug(ctx, "dialed workspace agent")
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {

View File

@ -1,9 +1,12 @@
// Package wsconncache caches workspace agent connections by UUID.
// Deprecated: Use ServerTailnet instead.
package wsconncache
import (
"context"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
@ -13,13 +16,57 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
)
// New creates a new workspace connection cache that closes
// connections after the inactive timeout provided.
type AgentProvider struct {
Cache *Cache
}
func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
conn, rel, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
return conn.WorkspaceAgentConn, rel, nil
}
func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: dashboardURL.String(),
})
}
conn, release, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
proxy.Transport = conn.HTTPTransport()
return proxy, release, nil
}
func (a *AgentProvider) Close() error {
return a.Cache.Close()
}
// New creates a new workspace connection cache that closes connections after
// the inactive timeout provided.
//
// Agent connections are cached due to WebRTC negotiation
// taking a few hundred milliseconds.
// Agent connections are cached due to Wireguard negotiation taking a few
// hundred milliseconds, depending on latency.
//
// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased
// out because it creates a unique Tailnet for each agent.
// See: https://github.com/coder/coder/issues/8218
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
if inactiveTimeout == 0 {
inactiveTimeout = 5 * time.Minute

View File

@ -157,22 +157,23 @@ func TestCache(t *testing.T) {
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
t.Helper()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
agentID := uuid.New()
manifest.AgentID = uuid.New()
closer := agent.New(agent.Options{
Client: &client{
t: t,
agentID: agentID,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
},
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
})
t.Cleanup(func() {
_ = closer.Close()
@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
_ = serverConn.Close()
_ = conn.Close()
})
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
return conn.UpdateNodes(node, false)
})
conn.SetNodeCallback(sendNode)
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
}
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
})
t.Cleanup(func() {
_ = agentConn.Close()
})