mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
chore: replace wsconncache with a single tailnet (#8176)
This commit is contained in:
5
coderd/apidoc/docs.go
generated
5
coderd/apidoc/docs.go
generated
@ -5961,6 +5961,9 @@ const docTemplate = `{
|
||||
"agentsdk.Manifest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"apps": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@ -7617,6 +7620,7 @@ const docTemplate = `{
|
||||
"workspace_actions",
|
||||
"tailnet_ha_coordinator",
|
||||
"convert-to-oidc",
|
||||
"single_tailnet",
|
||||
"workspace_build_logs_ui"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
@ -7624,6 +7628,7 @@ const docTemplate = `{
|
||||
"ExperimentWorkspaceActions",
|
||||
"ExperimentTailnetHACoordinator",
|
||||
"ExperimentConvertToOIDC",
|
||||
"ExperimentSingleTailnet",
|
||||
"ExperimentWorkspaceBuildLogsUI"
|
||||
]
|
||||
},
|
||||
|
5
coderd/apidoc/swagger.json
generated
5
coderd/apidoc/swagger.json
generated
@ -5251,6 +5251,9 @@
|
||||
"agentsdk.Manifest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"apps": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@ -6818,6 +6821,7 @@
|
||||
"workspace_actions",
|
||||
"tailnet_ha_coordinator",
|
||||
"convert-to-oidc",
|
||||
"single_tailnet",
|
||||
"workspace_build_logs_ui"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
@ -6825,6 +6829,7 @@
|
||||
"ExperimentWorkspaceActions",
|
||||
"ExperimentTailnetHACoordinator",
|
||||
"ExperimentConvertToOIDC",
|
||||
"ExperimentSingleTailnet",
|
||||
"ExperimentWorkspaceBuildLogsUI"
|
||||
]
|
||||
},
|
||||
|
@ -364,8 +364,23 @@ func New(options *Options) *API {
|
||||
}
|
||||
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
|
||||
if api.Experiments.Enabled(codersdk.ExperimentSingleTailnet) {
|
||||
api.agentProvider, err = NewServerTailnet(api.ctx,
|
||||
options.Logger,
|
||||
options.DERPServer,
|
||||
options.DERPMap,
|
||||
&api.TailnetCoordinator,
|
||||
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
|
||||
)
|
||||
if err != nil {
|
||||
panic("failed to setup server tailnet: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
api.agentProvider = &wsconncache.AgentProvider{
|
||||
Cache: wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
|
||||
}
|
||||
}
|
||||
|
||||
api.workspaceAppServer = &workspaceapps.Server{
|
||||
Logger: options.Logger.Named("workspaceapps"),
|
||||
@ -377,7 +392,7 @@ func New(options *Options) *API {
|
||||
RealIPConfig: options.RealIPConfig,
|
||||
|
||||
SignedTokenProvider: api.WorkspaceAppsProvider,
|
||||
WorkspaceConnCache: api.workspaceAgentCache,
|
||||
AgentProvider: api.agentProvider,
|
||||
AppSecurityKey: options.AppSecurityKey,
|
||||
|
||||
DisablePathApps: options.DeploymentValues.DisablePathApps.Value(),
|
||||
@ -921,10 +936,10 @@ type API struct {
|
||||
derpCloseFunc func()
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
workspaceAgentCache *wsconncache.Cache
|
||||
updateChecker *updatecheck.Checker
|
||||
WorkspaceAppsProvider workspaceapps.SignedTokenProvider
|
||||
workspaceAppServer *workspaceapps.Server
|
||||
agentProvider workspaceapps.AgentProvider
|
||||
|
||||
// Experiments contains the list of experiments currently enabled.
|
||||
// This is used to gate features that are not yet ready for production.
|
||||
@ -951,7 +966,8 @@ func (api *API) Close() error {
|
||||
if coordinator != nil {
|
||||
_ = (*coordinator).Close()
|
||||
}
|
||||
return api.workspaceAgentCache.Close()
|
||||
_ = api.agentProvider.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func compressHandler(h http.Handler) http.Handler {
|
||||
|
@ -109,6 +109,7 @@ type Options struct {
|
||||
GitAuthConfigs []*gitauth.Config
|
||||
TrialGenerator func(context.Context, string) error
|
||||
TemplateScheduleStore schedule.TemplateScheduleStore
|
||||
Coordinator tailnet.Coordinator
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthcheck.Report
|
||||
HealthcheckTimeout time.Duration
|
||||
|
@ -302,7 +302,7 @@ func TestAgents(t *testing.T) {
|
||||
coordinator := tailnet.NewCoordinator(slogtest.Make(t, nil).Leveled(slog.LevelDebug))
|
||||
coordinatorPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordinatorPtr.Store(&coordinator)
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
agentInactiveDisconnectTimeout := 1 * time.Hour // don't need to focus on this value in tests
|
||||
registry := prometheus.NewRegistry()
|
||||
|
||||
|
339
coderd/tailnet.go
Normal file
339
coderd/tailnet.go
Normal file
@ -0,0 +1,339 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
// tailnetTransport is the template transport that each ServerTailnet clones
// for its own use. It is populated once, in init.
var tailnetTransport *http.Transport

// init captures http.DefaultTransport as a concrete *http.Transport and panics
// if the default transport has been replaced with some other implementation.
func init() {
	defaultTransport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		panic("dev error: default transport is the wrong type")
	}
	tailnetTransport = defaultTransport
}
|
||||
|
||||
// NewServerTailnet creates a new tailnet intended for use by coderd. It
|
||||
// automatically falls back to wsconncache if a legacy agent is encountered.
|
||||
func NewServerTailnet(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
derpServer *derp.Server,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
coord *atomic.Pointer[tailnet.Coordinator],
|
||||
cache *wsconncache.Cache,
|
||||
) (*ServerTailnet, error) {
|
||||
logger = logger.Named("servertailnet")
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: derpMap,
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
||||
}
|
||||
|
||||
serverCtx, cancel := context.WithCancel(ctx)
|
||||
tn := &ServerTailnet{
|
||||
ctx: serverCtx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
conn: conn,
|
||||
coord: coord,
|
||||
cache: cache,
|
||||
agentNodes: map[uuid.UUID]time.Time{},
|
||||
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
|
||||
transport: tailnetTransport.Clone(),
|
||||
}
|
||||
tn.transport.DialContext = tn.dialContext
|
||||
tn.transport.MaxIdleConnsPerHost = 10
|
||||
tn.transport.MaxIdleConns = 0
|
||||
agentConn := (*coord.Load()).ServeMultiAgent(uuid.New())
|
||||
tn.agentConn.Store(&agentConn)
|
||||
|
||||
err = tn.getAgentConn().UpdateSelf(conn.Node())
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err))
|
||||
}
|
||||
conn.SetNodeCallback(func(node *tailnet.Node) {
|
||||
err := tn.getAgentConn().UpdateSelf(node)
|
||||
if err != nil {
|
||||
tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err))
|
||||
}
|
||||
})
|
||||
|
||||
// This is set to allow local DERP traffic to be proxied through memory
|
||||
// instead of needing to hit the external access URL. Don't use the ctx
|
||||
// given in this callback, it's only valid while connecting.
|
||||
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
||||
if !region.EmbeddedRelay {
|
||||
return nil
|
||||
}
|
||||
left, right := net.Pipe()
|
||||
go func() {
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
||||
derpServer.Accept(ctx, right, brw, "internal")
|
||||
}()
|
||||
return left
|
||||
})
|
||||
|
||||
go tn.watchAgentUpdates()
|
||||
go tn.expireOldAgents()
|
||||
return tn, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) expireOldAgents() {
|
||||
const (
|
||||
tick = 5 * time.Minute
|
||||
cutoff = 30 * time.Minute
|
||||
)
|
||||
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.nodesMu.Lock()
|
||||
agentConn := s.getAgentConn()
|
||||
for agentID, lastConnection := range s.agentNodes {
|
||||
// If no one has connected since the cutoff and there are no active
|
||||
// connections, remove the agent.
|
||||
if time.Since(lastConnection) > cutoff && len(s.agentTickets[agentID]) == 0 {
|
||||
_ = agentConn
|
||||
// err := agentConn.UnsubscribeAgent(agentID)
|
||||
// if err != nil {
|
||||
// s.logger.Error(s.ctx, "unsubscribe expired agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
// }
|
||||
// delete(s.agentNodes, agentID)
|
||||
|
||||
// TODO(coadler): actually remove from the netmap, then reenable
|
||||
// the above
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) watchAgentUpdates() {
|
||||
for {
|
||||
conn := s.getAgentConn()
|
||||
nodes, ok := conn.NextUpdate(s.ctx)
|
||||
if !ok {
|
||||
if conn.IsClosed() && s.ctx.Err() == nil {
|
||||
s.reinitCoordinator()
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.conn.UpdateNodes(nodes, false)
|
||||
if err != nil {
|
||||
s.logger.Error(context.Background(), "update node in server tailnet", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn {
|
||||
return *s.agentConn.Load()
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) reinitCoordinator() {
|
||||
s.nodesMu.Lock()
|
||||
agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New())
|
||||
s.agentConn.Store(&agentConn)
|
||||
|
||||
// Resubscribe to all of the agents we're tracking.
|
||||
for agentID := range s.agentNodes {
|
||||
err := agentConn.SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID))
|
||||
}
|
||||
}
|
||||
s.nodesMu.Unlock()
|
||||
}
|
||||
|
||||
type ServerTailnet struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
logger slog.Logger
|
||||
conn *tailnet.Conn
|
||||
coord *atomic.Pointer[tailnet.Coordinator]
|
||||
agentConn atomic.Pointer[tailnet.MultiAgentConn]
|
||||
cache *wsconncache.Cache
|
||||
nodesMu sync.Mutex
|
||||
// agentNodes is a map of agent tailnetNodes the server wants to keep a
|
||||
// connection to. It contains the last time the agent was connected to.
|
||||
agentNodes map[uuid.UUID]time.Time
|
||||
// agentTockets holds a map of all open connections to an agent.
|
||||
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: dashboardURL.String(),
|
||||
})
|
||||
}
|
||||
proxy.Director = s.director(agentID, proxy.Director)
|
||||
proxy.Transport = s.transport
|
||||
|
||||
return proxy, func() {}, nil
|
||||
}
|
||||
|
||||
type agentIDKey struct{}
|
||||
|
||||
// director makes sure agentIDKey is set on the context in the reverse proxy.
|
||||
// This allows the transport to correctly identify which agent to dial to.
|
||||
func (*ServerTailnet) director(agentID uuid.UUID, prev func(req *http.Request)) func(req *http.Request) {
|
||||
return func(req *http.Request) {
|
||||
ctx := context.WithValue(req.Context(), agentIDKey{}, agentID)
|
||||
*req = *req.WithContext(ctx)
|
||||
prev(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
agentID, ok := ctx.Value(agentIDKey{}).(uuid.UUID)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("no agent id attached")
|
||||
}
|
||||
|
||||
return s.DialAgentNetConn(ctx, agentID, network, addr)
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
|
||||
s.nodesMu.Lock()
|
||||
defer s.nodesMu.Unlock()
|
||||
|
||||
_, ok := s.agentNodes[agentID]
|
||||
// If we don't have the node, subscribe.
|
||||
if !ok {
|
||||
s.logger.Debug(s.ctx, "subscribing to agent", slog.F("agent_id", agentID))
|
||||
err := s.getAgentConn().SubscribeAgent(agentID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe agent: %w", err)
|
||||
}
|
||||
s.agentTickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
}
|
||||
|
||||
s.agentNodes[agentID] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
var (
|
||||
conn *codersdk.WorkspaceAgentConn
|
||||
ret = func() {}
|
||||
)
|
||||
|
||||
if s.getAgentConn().AgentIsLegacy(agentID) {
|
||||
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
|
||||
cconn, release, err := s.cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
|
||||
}
|
||||
|
||||
conn = cconn.WorkspaceAgentConn
|
||||
ret = release
|
||||
} else {
|
||||
err := s.ensureAgent(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
|
||||
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
})
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
reachable := conn.AwaitReachable(ctx)
|
||||
if !reachable {
|
||||
ret()
|
||||
conn.Close()
|
||||
return nil, nil, xerrors.New("agent is unreachable")
|
||||
}
|
||||
|
||||
return conn, ret, nil
|
||||
}
|
||||
|
||||
func (s *ServerTailnet) DialAgentNetConn(ctx context.Context, agentID uuid.UUID, network, addr string) (net.Conn, error) {
|
||||
conn, release, err := s.AgentConn(ctx, agentID)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("acquire agent conn: %w", err)
|
||||
}
|
||||
|
||||
// Since we now have an open conn, be careful to close it if we error
|
||||
// without returning it to the user.
|
||||
|
||||
nc, err := conn.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
release()
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return &netConnCloser{Conn: nc, close: func() {
|
||||
release()
|
||||
conn.Close()
|
||||
}}, err
|
||||
}
|
||||
|
||||
// netConnCloser wraps a net.Conn with an extra hook that runs whenever the
// connection is closed.
type netConnCloser struct {
	net.Conn
	// close is invoked by Close before the embedded Conn is closed.
	close func()
}

// Close runs the close hook first and then closes the embedded connection,
// returning that connection's close error.
func (c *netConnCloser) Close() error {
	c.close()
	err := c.Conn.Close()
	return err
}
|
||||
|
||||
func (s *ServerTailnet) Close() error {
|
||||
s.cancel()
|
||||
_ = s.cache.Close()
|
||||
_ = s.conn.Close()
|
||||
s.transport.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
207
coderd/tailnet_test.go
Normal file
207
coderd/tailnet_test.go
Normal file
@ -0,0 +1,207 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/agent/agenttest"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestServerTailnet_AgentConn_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
// Connect through the ServerTailnet
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
assert.True(t, conn.AwaitReachable(ctx))
|
||||
}
|
||||
|
||||
func TestServerTailnet_AgentConn_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
})
|
||||
|
||||
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
assert.True(t, conn.AwaitReachable(ctx))
|
||||
}
|
||||
|
||||
func TestServerTailnet_ReverseProxy_OK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, nil)
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestServerTailnet_ReverseProxy_Legacy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Force a connection through wsconncache using the legacy hardcoded ip.
|
||||
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
|
||||
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
|
||||
})
|
||||
|
||||
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
|
||||
require.NoError(t, err)
|
||||
|
||||
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
|
||||
require.NoError(t, err)
|
||||
defer release()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
u.String(),
|
||||
nil,
|
||||
).WithContext(ctx)
|
||||
|
||||
rp.ServeHTTP(rw, req)
|
||||
res := rw.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
derpMap, derpServer := tailnettest.RunDERPAndSTUN(t)
|
||||
manifest := agentsdk.Manifest{
|
||||
AgentID: uuid.New(),
|
||||
DERPMap: derpMap,
|
||||
}
|
||||
|
||||
var coordPtr atomic.Pointer[tailnet.Coordinator]
|
||||
coord := tailnet.NewCoordinator(logger)
|
||||
coordPtr.Store(&coord)
|
||||
t.Cleanup(func() {
|
||||
_ = coord.Close()
|
||||
})
|
||||
|
||||
c := agenttest.NewClient(t, manifest.AgentID, manifest, make(chan *agentsdk.Stats, 50), coord)
|
||||
|
||||
options := agent.Options{
|
||||
Client: c,
|
||||
Filesystem: afero.NewMemMapFs(),
|
||||
Logger: logger.Named("agent"),
|
||||
Addresses: agentAddresses,
|
||||
}
|
||||
|
||||
ag := agent.New(options)
|
||||
t.Cleanup(func() {
|
||||
_ = ag.Close()
|
||||
})
|
||||
|
||||
// Wait for the agent to connect.
|
||||
require.Eventually(t, func() bool {
|
||||
return coord.Node(manifest.AgentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: manifest.DERPMap,
|
||||
Logger: logger.Named("client"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
serveClientDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
<-serveClientDone
|
||||
})
|
||||
go func() {
|
||||
defer close(serveClientDone)
|
||||
coord.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
}()
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
CloseFunc: func() error { return codersdk.ErrSkipClose },
|
||||
}), nil
|
||||
}, 0)
|
||||
|
||||
serverTailnet, err := coderd.NewServerTailnet(
|
||||
context.Background(),
|
||||
logger,
|
||||
derpServer,
|
||||
manifest.DERPMap,
|
||||
&coordPtr,
|
||||
cache,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = serverTailnet.Close()
|
||||
})
|
||||
|
||||
return manifest.AgentID, ag, serverTailnet
|
||||
}
|
@ -161,6 +161,7 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.Manifest{
|
||||
AgentID: apiAgent.ID,
|
||||
Apps: convertApps(dbApps),
|
||||
DERPMap: api.DERPMap,
|
||||
GitAuthConfigs: len(api.GitAuthConfigs),
|
||||
@ -654,7 +655,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
|
||||
agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error dialing workspace agent.",
|
||||
@ -729,7 +730,9 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
||||
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
|
||||
}
|
||||
|
||||
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
// Deprecated: use api.tailnet.AgentConn instead.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
@ -765,14 +768,16 @@ func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.Workspac
|
||||
return nil
|
||||
})
|
||||
conn.SetNodeCallback(sendNodes)
|
||||
agentConn := &codersdk.WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
CloseFunc: func() {
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: agentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
CloseFunc: func() error {
|
||||
cancel()
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
go func() {
|
||||
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
|
||||
if err != nil {
|
||||
|
@ -399,7 +399,8 @@ func doWithRetries(t require.TestingT, client *codersdk.Client, req *http.Reques
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func requestWithRetries(ctx context.Context, t require.TestingT, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
|
||||
func requestWithRetries(ctx context.Context, t testing.TB, client *codersdk.Client, method, urlOrPath string, body interface{}, opts ...codersdk.RequestOption) (*http.Response, error) {
|
||||
t.Helper()
|
||||
var resp *http.Response
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/util/slice"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
)
|
||||
@ -61,6 +60,22 @@ var nonCanonicalHeaders = map[string]string{
|
||||
"Sec-Websocket-Version": "Sec-WebSocket-Version",
|
||||
}
|
||||
|
||||
type AgentProvider interface {
|
||||
// ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests
|
||||
// to the specified agent.
|
||||
//
|
||||
// TODO: after wsconncache is deleted this doesn't need to return an error.
|
||||
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error)
|
||||
|
||||
// AgentConn returns a new connection to the specified agent.
|
||||
//
|
||||
// TODO: after wsconncache is deleted this doesn't need to return a release
|
||||
// func.
|
||||
AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Server serves workspace apps endpoints, including:
|
||||
// - Path-based apps
|
||||
// - Subdomain app middleware
|
||||
@ -83,7 +98,6 @@ type Server struct {
|
||||
RealIPConfig *httpmw.RealIPConfig
|
||||
|
||||
SignedTokenProvider SignedTokenProvider
|
||||
WorkspaceConnCache *wsconncache.Cache
|
||||
AppSecurityKey SecurityKey
|
||||
|
||||
// DisablePathApps disables path-based apps. This is a security feature as path
|
||||
@ -95,6 +109,8 @@ type Server struct {
|
||||
DisablePathApps bool
|
||||
SecureAuthCookie bool
|
||||
|
||||
AgentProvider AgentProvider
|
||||
|
||||
websocketWaitMutex sync.Mutex
|
||||
websocketWaitGroup sync.WaitGroup
|
||||
}
|
||||
@ -106,8 +122,8 @@ func (s *Server) Close() error {
|
||||
s.websocketWaitGroup.Wait()
|
||||
s.websocketWaitMutex.Unlock()
|
||||
|
||||
// The caller must close the SignedTokenProvider (if necessary) and the
|
||||
// wsconncache.
|
||||
// The caller must close the SignedTokenProvider and the AgentProvider (if
|
||||
// necessary).
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -517,18 +533,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
|
||||
r.URL.Path = path
|
||||
appURL.RawQuery = ""
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(appURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: s.DashboardURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
conn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
|
||||
proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
|
||||
if err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
@ -540,7 +545,6 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
proxy.Transport = conn.HTTPTransport()
|
||||
|
||||
proxy.ModifyResponse = func(r *http.Response) error {
|
||||
r.Header.Del(httpmw.AccessControlAllowOriginHeader)
|
||||
@ -658,13 +662,14 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
agentConn, release, err := s.WorkspaceConnCache.Acquire(appToken.AgentID)
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
defer agentConn.Close()
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
|
||||
if err != nil {
|
||||
|
@ -1,9 +1,12 @@
|
||||
// Package wsconncache caches workspace agent connections by UUID.
|
||||
// Deprecated: Use ServerTailnet instead.
|
||||
package wsconncache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -13,13 +16,57 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
)
|
||||
|
||||
// New creates a new workspace connection cache that closes
|
||||
// connections after the inactive timeout provided.
|
||||
type AgentProvider struct {
|
||||
Cache *Cache
|
||||
}
|
||||
|
||||
func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
|
||||
conn, rel, err := a.Cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
|
||||
}
|
||||
|
||||
return conn.WorkspaceAgentConn, rel, nil
|
||||
}
|
||||
|
||||
func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) {
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
|
||||
Status: http.StatusBadGateway,
|
||||
Title: "Bad Gateway",
|
||||
Description: "Failed to proxy request to application: " + err.Error(),
|
||||
RetryEnabled: true,
|
||||
DashboardURL: dashboardURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
conn, release, err := a.Cache.Acquire(agentID)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
|
||||
}
|
||||
|
||||
proxy.Transport = conn.HTTPTransport()
|
||||
|
||||
return proxy, release, nil
|
||||
}
|
||||
|
||||
func (a *AgentProvider) Close() error {
|
||||
return a.Cache.Close()
|
||||
}
|
||||
|
||||
// New creates a new workspace connection cache that closes connections after
|
||||
// the inactive timeout provided.
|
||||
//
|
||||
// Agent connections are cached due to WebRTC negotiation
|
||||
// taking a few hundred milliseconds.
|
||||
// Agent connections are cached due to Wireguard negotiation taking a few
|
||||
// hundred milliseconds, depending on latency.
|
||||
//
|
||||
// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased
|
||||
// out because it creates a unique Tailnet for each agent.
|
||||
// See: https://github.com/coder/coder/issues/8218
|
||||
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
|
||||
if inactiveTimeout == 0 {
|
||||
inactiveTimeout = 5 * time.Minute
|
||||
|
@ -157,22 +157,23 @@ func TestCache(t *testing.T) {
|
||||
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
manifest.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
t.Cleanup(func() {
|
||||
_ = coordinator.Close()
|
||||
})
|
||||
agentID := uuid.New()
|
||||
manifest.AgentID = uuid.New()
|
||||
closer := agent.New(agent.Options{
|
||||
Client: &client{
|
||||
t: t,
|
||||
agentID: agentID,
|
||||
agentID: manifest.AgentID,
|
||||
manifest: manifest,
|
||||
coordinator: coordinator,
|
||||
},
|
||||
Logger: logger.Named("agent"),
|
||||
ReconnectingPTYTimeout: ptyTimeout,
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = closer.Close()
|
||||
@ -189,14 +190,15 @@ func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Durati
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), manifest.AgentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node, false)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
agentConn := &codersdk.WorkspaceAgentConn{
|
||||
Conn: conn,
|
||||
}
|
||||
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
|
||||
AgentID: manifest.AgentID,
|
||||
AgentIP: codersdk.WorkspaceAgentIP,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = agentConn.Close()
|
||||
})
|
||||
|
Reference in New Issue
Block a user