Files
coder/coderd/wsconncache/wsconncache_test.go
Spike Curtis ad3fed72bc chore: rename Coordinator to CoordinatorV1 (#11222)
Renames the tailnet.Coordinator to represent both v1 and v2 APIs, so that we can use this interface for the main atomic pointer.

Part of #10532
2023-12-15 11:38:12 +04:00

285 lines
7.4 KiB
Go

package wsconncache_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/coderd/wsconncache"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
// TestMain runs every test in the package through goleak.VerifyTestMain so
// that goroutines left running after the tests finish fail the run.
func TestMain(m *testing.M) { goleak.VerifyTestMain(m) }
// TestCache exercises wsconncache:
//   - reuse of a cached connection for the same ID;
//   - closure of connections after they are released;
//   - no closure while a connection is still acquired;
//   - concurrent HTTP proxying through one cached connection.
func TestCache(t *testing.T) {
	t.Parallel()
	t.Run("Same", func(t *testing.T) {
		t.Parallel()
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0), nil
		}, 0)
		defer func() {
			_ = cache.Close()
		}()
		conn1, _, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		conn2, _, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		// Both acquisitions for the same ID must yield the very same cached
		// connection object (pointer identity), not a second dial.
		require.Same(t, conn1, conn2)
	})
	t.Run("Expire", func(t *testing.T) {
		t.Parallel()
		// called counts how many times the cache invoked the dial function.
		called := atomic.NewInt32(0)
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			called.Add(1)
			return setupAgent(t, agentsdk.Manifest{}, 0), nil
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()
		// With the tiny timeout passed to New, each released connection is
		// expected to close, so the second Acquire must dial again.
		conn, release, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		release()
		<-conn.Closed()
		conn, release, err = cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		release()
		<-conn.Closed()
		require.Equal(t, int32(2), called.Load())
	})
	t.Run("NoExpireWhenLocked", func(t *testing.T) {
		t.Parallel()
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0), nil
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()
		conn, release, err := cache.Acquire(uuid.Nil)
		require.NoError(t, err)
		// Hold the connection for far longer than the 1µs timeout. Because it
		// is still acquired, it must not have been closed in the meantime;
		// without this check the subtest would also pass if locked
		// connections were expired.
		time.Sleep(time.Millisecond)
		select {
		case <-conn.Closed():
			t.Fatal("connection expired while still acquired")
		default:
		}
		release()
		// Once released, the connection is expected to be closed.
		<-conn.Closed()
	})
	t.Run("HTTPTransport", func(t *testing.T) {
		t.Parallel()
		// Serve a trivial 200 OK handler on a random loopback port; requests
		// below are proxied to it using the cached connection's transport.
		random, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer func() {
			_ = random.Close()
		}()
		tcpAddr, valid := random.Addr().(*net.TCPAddr)
		require.True(t, valid)
		server := &http.Server{
			ReadHeaderTimeout: time.Minute,
			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
			}),
		}
		defer func() {
			_ = server.Close()
		}()
		go server.Serve(random)
		cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
			return setupAgent(t, agentsdk.Manifest{}, 0), nil
		}, time.Microsecond)
		defer func() {
			_ = cache.Close()
		}()
		var wg sync.WaitGroup
		// Perform many requests in parallel to simulate
		// simultaneous HTTP requests.
		for i := 0; i < 50; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				proxy := httputil.NewSingleHostReverseProxy(&url.URL{
					Scheme: "http",
					Host:   fmt.Sprintf("127.0.0.1:%d", tcpAddr.Port),
					Path:   "/",
				})
				ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
				defer cancel()
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req = req.WithContext(ctx)
				// Only assert/t.Error here: FailNow-style helpers must not
				// be called from a non-test goroutine.
				conn, release, err := cache.Acquire(uuid.Nil)
				if !assert.NoError(t, err) {
					return
				}
				defer release()
				if !conn.AwaitReachable(ctx) {
					t.Error("agent not reachable")
					return
				}
				transport := conn.HTTPTransport()
				defer transport.CloseIdleConnections()
				proxy.Transport = transport
				res := httptest.NewRecorder()
				proxy.ServeHTTP(res, req)
				resp := res.Result()
				defer resp.Body.Close()
				assert.Equal(t, http.StatusOK, resp.StatusCode)
			}()
		}
		wg.Wait()
	})
}
// setupAgent starts an in-process agent wired to a tailnet coordinator, dials
// it over tailnet, and returns a WorkspaceAgentConn that has already passed
// AwaitReachable. It fails the test if the agent is not reachable within
// testutil.WaitMedium.
//
// manifest is received by value. Its DERPMap (from
// tailnettest.RunDERPAndSTUN) and AgentID (a fresh UUID) are overwritten on
// that copy before it is handed to the stub client. ptyTimeout is passed
// through as the agent's ReconnectingPTYTimeout. Everything created here is
// torn down via t.Cleanup.
func setupAgent(t *testing.T, manifest agentsdk.Manifest, ptyTimeout time.Duration) *codersdk.WorkspaceAgentConn {
	t.Helper()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	manifest.DERPMap, _ = tailnettest.RunDERPAndSTUN(t)

	coord := tailnet.NewCoordinator(logger)
	t.Cleanup(func() { _ = coord.Close() })

	agentID := uuid.New()
	manifest.AgentID = agentID

	// Named agentCloser (not "closer") so it does not shadow the closer type
	// declared in this file.
	agentCloser := agent.New(agent.Options{
		Client: &client{
			t:           t,
			agentID:     agentID,
			manifest:    manifest,
			coordinator: coord,
		},
		Logger:                 logger.Named("agent"),
		ReconnectingPTYTimeout: ptyTimeout,
		Addresses:              []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)},
	})
	t.Cleanup(func() { _ = agentCloser.Close() })

	tailConn, err := tailnet.NewConn(&tailnet.Options{
		Addresses:           []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
		DERPMap:             manifest.DERPMap,
		DERPForceWebSockets: manifest.DERPForceWebSockets,
		Logger:              slogtest.Make(t, nil).Named("tailnet").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)

	// Connect this client's tailnet conn to the coordinator over an
	// in-memory pipe. The coordinator serves one end as a client of agentID,
	// and ServeCoordinator drives the other end. Peer nodes received from the
	// coordinator are applied to tailConn. Local node changes are sent back
	// through sendNode.
	tailnetEnd, coordEnd := net.Pipe()
	t.Cleanup(func() {
		_ = tailnetEnd.Close()
		_ = coordEnd.Close()
		_ = tailConn.Close()
	})
	go coord.ServeClient(coordEnd, uuid.New(), agentID)
	applyNodes := func(nodes []*tailnet.Node) error {
		return tailConn.UpdateNodes(nodes, false)
	}
	sendNode, _ := tailnet.ServeCoordinator(tailnetEnd, applyNodes)
	tailConn.SetNodeCallback(sendNode)

	agentConn := codersdk.NewWorkspaceAgentConn(tailConn, codersdk.WorkspaceAgentConnOptions{
		AgentID: agentID,
		AgentIP: codersdk.WorkspaceAgentIP,
	})
	t.Cleanup(func() { _ = agentConn.Close() })

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	if reachable := agentConn.AwaitReachable(ctx); !reachable {
		t.Fatal("agent not reachable")
	}
	return agentConn
}
// client is the stub agent API client that setupAgent hands to agent.New.
// It serves a fixed manifest and routes the agent's coordination connection
// to a tailnet coordinator. Its reporting endpoints do nothing.
type client struct {
	// t is used to register cleanup for the pipes created by Listen.
	t *testing.T
	// agentID identifies the agent to the coordinator in Listen.
	agentID uuid.UUID
	// coordinator serves the coordinator end of each Listen pipe.
	coordinator tailnet.CoordinatorV1
	// manifest is returned as-is by Manifest.
	manifest agentsdk.Manifest
}
// Manifest returns the manifest the stub was built with; it never fails.
func (c *client) Manifest(context.Context) (agentsdk.Manifest, error) {
	m := c.manifest
	return m, nil
}
// closer adapts a plain function to the io.Closer interface.
type closer struct {
	// closeFunc is invoked on every call to Close.
	closeFunc func() error
}

// Close calls the wrapped function and returns whatever it returns.
func (c *closer) Close() error {
	fn := c.closeFunc
	return fn()
}
// DERPMapUpdates returns a channel on which no DERP map update is ever sent,
// together with a no-op closer.
//
// The closer holds no state. Calling Close more than once is therefore safe;
// the previous version closed an otherwise-unused channel and panicked on a
// second Close.
func (*client) DERPMapUpdates(_ context.Context) (<-chan agentsdk.DERPMapUpdate, io.Closer, error) {
	// Never written to and never closed: receiving from it blocks indefinitely.
	updates := make(chan agentsdk.DERPMapUpdate)
	return updates, &closer{
		closeFunc: func() error { return nil },
	}, nil
}
// Listen returns one end of an in-memory pipe to the agent. It serves the
// other end with c.coordinator.ServeAgent in a background goroutine.
//
// The registered test cleanup closes both ends and then waits for that
// goroutine to finish, so it does not outlive the test.
func (c *client) Listen(_ context.Context) (net.Conn, error) {
	agentEnd, coordEnd := net.Pipe()
	served := make(chan struct{})
	c.t.Cleanup(func() {
		_ = coordEnd.Close()
		_ = agentEnd.Close()
		<-served
	})
	go func() {
		defer close(served)
		// Error deliberately ignored. The cleanup above relies on ServeAgent
		// returning once the pipe is closed (presumably as an error).
		_ = c.coordinator.ServeAgent(coordEnd, c.agentID, "")
	}()
	return agentEnd, nil
}
// ReportStats ignores the stats stream and returns a closer whose Close does
// nothing.
func (*client) ReportStats(context.Context, slog.Logger, <-chan *agentsdk.Stats, func(time.Duration)) (io.Closer, error) {
	empty := strings.NewReader("")
	return io.NopCloser(empty), nil
}
// PostLifecycle accepts and discards lifecycle reports.
func (*client) PostLifecycle(context.Context, agentsdk.PostLifecycleRequest) error {
	return nil
}
// PostAppHealth accepts and discards app health reports.
func (*client) PostAppHealth(context.Context, agentsdk.PostAppHealthsRequest) error {
	return nil
}
// PostMetadata accepts and discards metadata reports.
func (*client) PostMetadata(context.Context, agentsdk.PostMetadataRequest) error {
	return nil
}
// PostStartup accepts and discards startup reports.
func (*client) PostStartup(context.Context, agentsdk.PostStartupRequest) error {
	return nil
}
// PatchLogs accepts and discards log patches.
func (*client) PatchLogs(context.Context, agentsdk.PatchLogs) error {
	return nil
}
func (*client) GetServiceBanner(_ context.Context) (codersdk.ServiceBannerConfig, error) {
return codersdk.ServiceBannerConfig{}, nil
}