chore: remove legacy wsconncache (#11816)

Fixes #8218

Removes `wsconncache` and the related "is legacy?" functions and API calls that existed only to support it.

The only leftover is that agents still use the legacy IP, so that back-level clients and workspace proxies can dial them correctly.
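
For illustration, here is a minimal sketch of that leftover, mirroring the address setup visible in the removed tests below. `codersdk.WorkspaceAgentIP` is the hardcoded legacy address; pairing it with a unique `tailnet.IP()` address is an assumption for the sketch, not something this diff shows directly:

```go
package main

import (
	"fmt"
	"net/netip"

	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/tailnet"
)

func main() {
	// Agents keep advertising the well-known legacy IP alongside a
	// unique per-agent address, so back-level clients and workspace
	// proxies that dial the hardcoded address keep working.
	addresses := []netip.Prefix{
		netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128), // legacy, shared by all agents
		netip.PrefixFrom(tailnet.IP(), 128),              // unique per agent (assumed here)
	}
	fmt.Println(addresses)
}
```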

We should eventually remove this: #11819
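
To make the shape of the change concrete, here is a condensed before/after of the `workspaceapps.AgentProvider` interface. The interface names are hypothetical (for side-by-side comparison); the method signatures are taken from the diff below:

```go
package workspaceapps_sketch // hypothetical package name, for illustration only

import (
	"context"
	"net/http/httputil"
	"net/url"

	"github.com/google/uuid"

	"github.com/coder/coder/v2/codersdk"
)

// agentProviderBefore condenses the old interface: wsconncache handed out
// refcounted cached conns, so both methods returned a release func, and
// ReverseProxy could fail while acquiring the conn.
type agentProviderBefore interface {
	ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error)
	AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error)
}

// agentProviderAfter condenses the new interface: every agent is reachable
// over the single shared ServerTailnet conn, so the proxy is constructed
// directly with nothing to release. AgentConn keeps its release func, which
// tracks per-agent tickets, per the diff below.
type agentProviderAfter interface {
	ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy
	AgentConn(ctx context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error)
}
```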
Spike Curtis
2024-01-30 07:56:36 +04:00
committed by GitHub
parent 13e24f21e4
commit 1e8a9c09fe
24 changed files with 36 additions and 1238 deletions

coderd/apidoc/docs.go (generated)

@@ -5822,44 +5822,6 @@ const docTemplate = `{
}
}
},
"/workspaceagents/{workspaceagent}/legacy": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Enterprise"
],
"summary": "Agent is legacy",
"operationId": "agent-is-legacy",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace Agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceagents/{workspaceagent}/listening-ports": {
"get": {
"security": [
@@ -13811,17 +13773,6 @@ const docTemplate = `{
}
}
},
"wsproxysdk.AgentIsLegacyResponse": {
"type": "object",
"properties": {
"found": {
"type": "boolean"
},
"legacy": {
"type": "boolean"
}
}
},
"wsproxysdk.DeregisterWorkspaceProxyRequest": {
"type": "object",
"properties": {

coderd/apidoc/swagger.json (generated)

@@ -5120,40 +5120,6 @@
}
}
},
"/workspaceagents/{workspaceagent}/legacy": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Enterprise"],
"summary": "Agent is legacy",
"operationId": "agent-is-legacy",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace Agent ID",
"name": "workspaceagent",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse"
}
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/workspaceagents/{workspaceagent}/listening-ports": {
"get": {
"security": [
@@ -12604,17 +12570,6 @@
}
}
},
"wsproxysdk.AgentIsLegacyResponse": {
"type": "object",
"properties": {
"found": {
"type": "boolean"
},
"legacy": {
"type": "boolean"
}
}
},
"wsproxysdk.DeregisterWorkspaceProxyRequest": {
"type": "object",
"properties": {

coderd/coderd.go

@@ -65,7 +65,6 @@ import (
"github.com/coder/coder/v2/coderd/updatecheck"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/provisionerd/proto"
@@ -481,7 +480,6 @@ func New(options *Options) *API {
func(context.Context) (tailnet.MultiAgentConn, error) {
return (*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil
},
wsconncache.New(api._dialWorkspaceAgentTailnet, 0),
api.TracerProvider,
)
if err != nil {

coderd/tailnet.go

@@ -22,7 +22,6 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/site"
"github.com/coder/coder/v2/tailnet"
@@ -41,8 +40,7 @@ func init() {
var _ workspaceapps.AgentProvider = (*ServerTailnet)(nil)
// NewServerTailnet creates a new tailnet intended for use by coderd. It
// automatically falls back to wsconncache if a legacy agent is encountered.
// NewServerTailnet creates a new tailnet intended for use by coderd.
func NewServerTailnet(
ctx context.Context,
logger slog.Logger,
@@ -50,7 +48,6 @@ func NewServerTailnet(
derpMapFn func() *tailcfg.DERPMap,
derpForceWebSockets bool,
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error),
cache *wsconncache.Cache,
traceProvider trace.TracerProvider,
) (*ServerTailnet, error) {
logger = logger.Named("servertailnet")
@@ -97,7 +94,6 @@ func NewServerTailnet(
conn: conn,
coordinatee: conn,
getMultiAgent: getMultiAgent,
cache: cache,
agentConnectionTimes: map[uuid.UUID]time.Time{},
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
transport: tailnetTransport.Clone(),
@@ -299,7 +295,6 @@ type ServerTailnet struct {
getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error)
agentConn atomic.Pointer[tailnet.MultiAgentConn]
cache *wsconncache.Cache
nodesMu sync.Mutex
// agentConnectionTimes is a map of agent tailnetNodes the server wants to
// keep a connection to. It contains the last time the agent was connected
@@ -311,7 +306,7 @@ type ServerTailnet struct {
transport *http.Transport
}
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error) {
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
@@ -325,7 +320,7 @@ func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID u
proxy.Director = s.director(agentID, proxy.Director)
proxy.Transport = s.transport
return proxy, func() {}, nil
return proxy
}
type agentIDKey struct{}
@@ -387,28 +382,17 @@ func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*code
ret func()
)
if s.getAgentConn().AgentIsLegacy(agentID) {
s.logger.Debug(s.ctx, "acquiring legacy agent", slog.F("agent_id", agentID))
cconn, release, err := s.cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire legacy agent conn: %w", err)
}
conn = cconn.WorkspaceAgentConn
ret = release
} else {
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
s.logger.Debug(s.ctx, "acquiring agent", slog.F("agent_id", agentID))
err := s.ensureAgent(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("ensure agent: %w", err)
}
ret = s.acquireTicket(agentID)
conn = codersdk.NewWorkspaceAgentConn(s.conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
// Since we now have an open conn, be careful to close it if we error
// without returning it to the user.
@@ -458,7 +442,6 @@ func (c *netConnCloser) Close() error {
func (s *ServerTailnet) Close() error {
s.cancel()
_ = s.cache.Close()
_ = s.conn.Close()
s.transport.CloseIdleConnections()
<-s.derpMapUpdaterClosed

coderd/tailnet_test.go

@@ -21,7 +21,6 @@ import (
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
@@ -45,24 +44,6 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_AgentConn_Legacy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
conn, release, err := serverTailnet.AgentConn(ctx, agentID)
require.NoError(t, err)
defer release()
assert.True(t, conn.AwaitReachable(ctx))
}
func TestServerTailnet_ReverseProxy(t *testing.T) {
t.Parallel()
@@ -77,9 +58,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rp := serverTailnet.ReverseProxy(u, u, agentID)
rw := httptest.NewRecorder()
req := httptest.NewRequest(
@@ -113,9 +92,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
uri, err := url.Parse(s.URL)
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(uri, uri, agentID)
require.NoError(t, err)
defer release()
rp := serverTailnet.ReverseProxy(uri, uri, agentID)
rw := httptest.NewRecorder()
req := httptest.NewRequest(
@@ -130,38 +107,6 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
assert.Equal(t, expectedResponseCode, res.StatusCode)
})
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Force a connection through wsconncache using the legacy hardcoded ip.
agentID, _, serverTailnet := setupAgent(t, []netip.Prefix{
netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128),
})
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
require.NoError(t, err)
rp, release, err := serverTailnet.ReverseProxy(u, u, agentID)
require.NoError(t, err)
defer release()
rw := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
u.String(),
nil,
).WithContext(ctx)
rp.ServeHTTP(rw, req)
res := rw.Result()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
})
}
func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.Agent, *coderd.ServerTailnet) {
@@ -197,34 +142,6 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
return coord.Node(manifest.AgentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
Logger: logger.Named("client"),
})
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
clientID := uuid.New()
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, manifest.AgentID,
coord, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
})
return codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error { return codersdk.ErrSkipClose },
}), nil
}, 0)
serverTailnet, err := coderd.NewServerTailnet(
context.Background(),
logger,
@@ -232,7 +149,6 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
func() *tailcfg.DERPMap { return manifest.DERPMap },
false,
func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil },
cache,
trace.NewNoopTracerProvider(),
)
require.NoError(t, err)

coderd/workspaceagents.go

@@ -1,7 +1,6 @@
package coderd
import (
"bufio"
"context"
"database/sql"
"encoding/json"
@@ -10,7 +9,6 @@ import (
"io"
"net"
"net/http"
"net/netip"
"net/url"
"sort"
"strconv"
@@ -861,81 +859,6 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
}
// Deprecated: use api.tailnet.AgentConn instead.
// See: https://github.com/coder/coder/issues/8218
func (api *API) _dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
derpMap := api.DERPMap()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: api.DERPMap(),
DERPForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(),
Logger: api.Logger.Named("net.tailnet"),
BlockEndpoints: api.DeploymentValues.DERP.Config.BlockDirect.Value(),
})
if err != nil {
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
ctx, cancel := context.WithCancel(api.ctx)
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
}
left, right := net.Pipe()
go func() {
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
api.DERPServer.Accept(ctx, right, brw, "internal")
}()
return left
})
clientID := uuid.New()
coordination := tailnet.NewInMemoryCoordination(ctx, api.Logger,
clientID, agentID,
*(api.TailnetCoordinator.Load()), conn)
// Check for updated DERP map every 5 seconds.
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
lastDERPMap := derpMap
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
derpMap := api.DERPMap()
if lastDERPMap == nil || !tailnet.CompareDERPMaps(lastDERPMap, derpMap) {
conn.SetDERPMap(derpMap)
lastDERPMap = derpMap
}
ticker.Reset(5 * time.Second)
}
}
}()
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
AgentIP: codersdk.WorkspaceAgentIP,
CloseFunc: func() error {
_ = coordination.Close()
cancel()
return nil
},
})
if !agentConn.AwaitReachable(ctx) {
_ = agentConn.Close()
cancel()
return nil, xerrors.Errorf("agent not reachable")
}
return agentConn, nil
}
// @Summary Get connection info for workspace agent
// @ID get-connection-info-for-workspace-agent
// @Security CoderSessionToken

coderd/workspaceapps/proxy.go

@@ -65,14 +65,9 @@ var nonCanonicalHeaders = map[string]string{
type AgentProvider interface {
// ReverseProxy returns an httputil.ReverseProxy for proxying HTTP requests
// to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return an error.
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) (_ *httputil.ReverseProxy, release func(), _ error)
ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy
// AgentConn returns a new connection to the specified agent.
//
// TODO: after wsconncache is deleted this doesn't need to return a release
// func.
AgentConn(ctx context.Context, agentID uuid.UUID) (_ *codersdk.WorkspaceAgentConn, release func(), _ error)
ServeHTTPDebug(w http.ResponseWriter, r *http.Request)
@@ -548,18 +543,7 @@ func (s *Server) proxyWorkspaceApp(rw http.ResponseWriter, r *http.Request, appT
r.URL.Path = path
appURL.RawQuery = ""
proxy, release, err := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
if err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Could not connect to workspace agent: " + err.Error(),
RetryEnabled: true,
DashboardURL: s.DashboardURL.String(),
})
return
}
defer release()
proxy := s.AgentProvider.ReverseProxy(appURL, s.DashboardURL, appToken.AgentID)
proxy.ModifyResponse = func(r *http.Response) error {
r.Header.Del(httpmw.AccessControlAllowOriginHeader)

coderd/wsconncache/wsconncache.go (deleted)

@@ -1,232 +0,0 @@
// Package wsconncache caches workspace agent connections by UUID.
// Deprecated: Use ServerTailnet instead.
package wsconncache
import (
"context"
"crypto/tls"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/atomic"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/site"
)
var _ workspaceapps.AgentProvider = (*AgentProvider)(nil)
type AgentProvider struct {
Cache *Cache
}
func (a *AgentProvider) AgentConn(_ context.Context, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, func(), error) {
conn, rel, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
return conn.WorkspaceAgentConn, rel, nil
}
func (a *AgentProvider) ReverseProxy(targetURL *url.URL, dashboardURL *url.URL, agentID uuid.UUID) (*httputil.ReverseProxy, func(), error) {
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
site.RenderStaticErrorPage(w, r, site.ErrorPageData{
Status: http.StatusBadGateway,
Title: "Bad Gateway",
Description: "Failed to proxy request to application: " + err.Error(),
RetryEnabled: true,
DashboardURL: dashboardURL.String(),
})
}
conn, release, err := a.Cache.Acquire(agentID)
if err != nil {
return nil, nil, xerrors.Errorf("acquire agent connection: %w", err)
}
transport := conn.HTTPTransport()
proxy.Transport = transport
return proxy, release, nil
}
func (*AgentProvider) ServeHTTPDebug(http.ResponseWriter, *http.Request) {}
func (a *AgentProvider) Close() error {
return a.Cache.Close()
}
// New creates a new workspace connection cache that closes connections after
// the inactive timeout provided.
//
// Agent connections are cached due to Wireguard negotiation taking a few
// hundred milliseconds, depending on latency.
//
// Deprecated: Use coderd.NewServerTailnet instead. wsconncache is being phased
// out because it creates a unique Tailnet for each agent.
// See: https://github.com/coder/coder/issues/8218
func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
if inactiveTimeout == 0 {
inactiveTimeout = 5 * time.Minute
}
return &Cache{
closed: make(chan struct{}),
dialer: dialer,
inactiveTimeout: inactiveTimeout,
}
}
// Dialer creates a new agent connection by ID.
type Dialer func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
*codersdk.WorkspaceAgentConn
locks atomic.Uint64
timeoutMutex sync.Mutex
timeout *time.Timer
timeoutCancel context.CancelFunc
transport *http.Transport
}
func (c *Conn) HTTPTransport() *http.Transport {
return c.transport
}
// Close shuts down the HTTP transport, if one exists, and closes the agent connection.
func (c *Conn) Close() error {
if c.transport != nil {
c.transport.CloseIdleConnections()
}
c.timeoutMutex.Lock()
defer c.timeoutMutex.Unlock()
if c.timeout != nil {
c.timeout.Stop()
}
return c.WorkspaceAgentConn.Close()
}
type Cache struct {
closed chan struct{}
closeMutex sync.Mutex
closeGroup sync.WaitGroup
connGroup singleflight.Group
connMap sync.Map
dialer Dialer
inactiveTimeout time.Duration
}
// Acquire gets or establishes a connection with the dialer using the ID provided.
// If a connection is in-progress, that connection or error will be returned.
//
// The returned function is used to release a lock on the connection. Once zero
// locks exist on a connection, the inactive timeout will begin to tick down.
// After the time expires, the connection will be cleared from the cache.
func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) {
rawConn, found := c.connMap.Load(id.String())
// If the connection isn't found, establish a new one!
if !found {
var err error
// A singleflight group is used to allow for concurrent requests to the
// same identifier to resolve.
rawConn, err, _ = c.connGroup.Do(id.String(), func() (interface{}, error) {
c.closeMutex.Lock()
select {
case <-c.closed:
c.closeMutex.Unlock()
return nil, xerrors.New("closed")
default:
}
c.closeGroup.Add(1)
c.closeMutex.Unlock()
agentConn, err := c.dialer(id)
if err != nil {
c.closeGroup.Done()
return nil, xerrors.Errorf("dial: %w", err)
}
timeoutCtx, timeoutCancelFunc := context.WithCancel(context.Background())
defaultTransport, valid := http.DefaultTransport.(*http.Transport)
if !valid {
panic("dev error: default transport is the wrong type")
}
transport := defaultTransport.Clone()
transport.DialContext = agentConn.DialContext
// We intentionally don't verify the certificate chain here.
// The connection to the workspace is already established and most
// apps are already going to be accessed over plain HTTP, this config
// simply allows apps being run over HTTPS to be accessed without error --
// many of which may be using self-signed certs.
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
//nolint:gosec
InsecureSkipVerify: true,
}
conn := &Conn{
WorkspaceAgentConn: agentConn,
timeoutCancel: timeoutCancelFunc,
transport: transport,
}
go func() {
defer c.closeGroup.Done()
select {
case <-timeoutCtx.Done():
case <-c.closed:
case <-conn.Closed():
}
c.connMap.Delete(id.String())
c.connGroup.Forget(id.String())
transport.CloseIdleConnections()
_ = conn.Close()
}()
return conn, nil
})
if err != nil {
return nil, nil, err
}
c.connMap.Store(id.String(), rawConn)
}
conn, _ := rawConn.(*Conn)
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Inc()
return conn, func() {
conn.timeoutMutex.Lock()
defer conn.timeoutMutex.Unlock()
if conn.timeout != nil {
conn.timeout.Stop()
}
conn.locks.Dec()
if conn.locks.Load() == 0 {
conn.timeout = time.AfterFunc(c.inactiveTimeout, conn.timeoutCancel)
}
}, nil
}
func (c *Cache) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return nil
default:
}
close(c.closed)
c.closeGroup.Wait()
return nil
}

coderd/wsconncache/wsconncache_test.go (deleted)

@@ -1,336 +0,0 @@
package wsconncache_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0)
}, 0)
defer func() {
_ = cache.Close()
}()
conn1, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
conn2, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
require.True(t, conn1 == conn2)
})
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := int32(0)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
atomic.AddInt32(&called, 1)
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
conn, release, err = cache.Acquire(uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
require.Equal(t, int32(2), called)
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
time.Sleep(time.Millisecond)
release()
<-conn.Closed()
})
t.Run("HTTPTransport", func(t *testing.T) {
t.Parallel()
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = random.Close()
}()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
server := &http.Server{
ReadHeaderTimeout: time.Minute,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
}
defer func() {
_ = server.Close()
}()
go server.Serve(random)
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Manifest{}, 0)
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
var wg sync.WaitGroup
// Perform many requests in parallel to simulate
// simultaneous HTTP requests.
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port),
Path: "/",
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx)
conn, release, err := cache.Acquire(uuid.Nil)
if !assert.NoError(t, err) {
return
}
defer release()
if !conn.AwaitReachable(ctx) {
t.Error("agent not reachable")
return
}
transport := conn.HTTPTransport()
defer transport.CloseIdleConnections()
proxy.Transport = transport
res := httptest.NewRecorder()
proxy.ServeHTTP(res, req)
resp := res.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
}()
}
wg.Wait()
})
}
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) (*codersdk.WorkspaceAgentConn, error) {
t.Helper()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
manifest.AgentID = uuid.New()
aC := newClient(
t,
slogtest.Make(t, nil).Leveled(slog.LevelDebug),
manifest,
coordinator,
)
t.Cleanup(aC.close)
closer := agent.New(agent.Options{
Client: aC,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
})
t.Cleanup(func() {
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: manifest.DERPMap,
DERPForceWebSockets: manifest.DERPForceWebSockets,
Logger: slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
})
// setupAgent is called by wsconncache Dialer, so we can't use require here as it will end the
// test, which in turn closes the wsconncache, which in turn waits for the Dialer and deadlocks.
if !assert.NoError(t, err) {
return nil, err
}
t.Cleanup(func() {
_ = conn.Close()
})
clientID := uuid.New()
testCtx, testCtxCancel := context.WithCancel(context.Background())
t.Cleanup(testCtxCancel)
coordination := tailnet.NewInMemoryCoordination(
testCtx, logger,
clientID, manifest.AgentID,
coordinator, conn,
)
t.Cleanup(func() {
_ = coordination.Close()
})
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: manifest.AgentID,
AgentIP: codersdk.WorkspaceAgentIP,
})
t.Cleanup(func() {
_ = agentConn.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
if !agentConn.AwaitReachable(ctx) {
// setupAgent is called by wsconncache Dialer, so we can't use t.Fatal here as it will end
// the test, which in turn closes the wsconncache, which in turn waits for the Dialer and
// deadlocks.
t.Error("agent not reachable")
return nil, xerrors.New("agent not reachable")
}
return agentConn, nil
}
type client struct {
t *testing.T
agentID uuid.UUID
manifest agentsdk.Manifest
coordinator tailnet.Coordinator
closeOnce sync.Once
derpMapUpdates chan *tailcfg.DERPMap
server *drpcserver.Server
fakeAgentAPI *agenttest.FakeAgentAPI
}
func newClient(t *testing.T, logger slog.Logger, manifest agentsdk.Manifest, coordinator tailnet.Coordinator) *client {
logger = logger.Named("drpc")
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coordinator)
mux := drpcmux.New()
derpMapUpdates := make(chan *tailcfg.DERPMap)
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
}
err := proto.DRPCRegisterTailnet(mux, drpcService)
require.NoError(t, err)
fakeAAPI := agenttest.NewFakeAgentAPI(t, logger)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
return &client{
t: t,
agentID: manifest.AgentID,
manifest: manifest,
coordinator: coordinator,
derpMapUpdates: derpMapUpdates,
server: server,
fakeAgentAPI: fakeAAPI,
}
}
func (c *client) close() {
c.closeOnce.Do(func() { close(c.derpMapUpdates) })
}
func (c *client) Manifest(_ context.Context) (agentsdk.Manifest, error) {
return c.manifest, nil
}
func (c *client) Listen(_ context.Context) (drpc.Conn, error) {
conn, lis := drpcsdk.MemTransportPipe()
c.t.Cleanup(func() {
_ = conn.Close()
_ = lis.Close()
})
serveCtx, cancel := context.WithCancel(context.Background())
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "wsconncache_test-agent",
ID: c.agentID,
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
c.server.Serve(serveCtx, lis)
}()
return conn, nil
}
func (*client) ReportStats(_ context.Context, _ slog.Logger, _ <-chan *agentsdk.Stats, _ func(time.Duration)) (io.Closer, error) {
return io.NopCloser(strings.NewReader("")), nil
}
func (*client) PostLifecycle(_ context.Context, _ agentsdk.PostLifecycleRequest) error {
return nil
}
func (*client) PostAppHealth(_ context.Context, _ agentsdk.PostAppHealthsRequest) error {
return nil
}
func (*client) PostMetadata(_ context.Context, _ agentsdk.PostMetadataRequest) error {
return nil
}
func (*client) PostStartup(_ context.Context, _ agentsdk.PostStartupRequest) error {
return nil
}
func (*client) PatchLogs(_ context.Context, _ agentsdk.PatchLogs) error {
return nil
}