fix: use a background context when piping derp connections (#6750)

This was causing boatloads of connects to reestablish every time...

See https://github.com/coder/coder/issues/6746
This commit is contained in:
Kyle Carberry
2023-03-23 09:54:07 -05:00
committed by GitHub
parent 7949db8e03
commit ed9a3b9251
4 changed files with 21 additions and 21 deletions

View File

@ -294,7 +294,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
go httpapi.Heartbeat(ctx, conn)
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return
@ -339,7 +339,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return
}
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.",
@ -414,10 +414,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
}
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
ctx := r.Context()
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
clientConn, serverConn := net.Pipe()
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: api.DERPMap,
@ -428,6 +426,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
_ = serverConn.Close()
return nil, xerrors.Errorf("create tailnet conn: %w", err)
}
ctx, cancel := context.WithCancel(api.ctx)
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay {
return nil
@ -437,7 +436,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
defer left.Close()
defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
api.DERPServer.Accept(ctx, right, brw, r.RemoteAddr)
api.DERPServer.Accept(ctx, right, brw, "internal")
}()
return left
})
@ -453,6 +452,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn,
CloseFunc: func() {
cancel()
_ = clientConn.Close()
_ = serverConn.Close()
},
@ -460,7 +460,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
_ = agentConn.Close()
}
}()

View File

@ -639,7 +639,7 @@ func (api *API) proxyWorkspaceApplication(rw http.ResponseWriter, r *http.Reques
})
}
conn, release, err := api.workspaceAgentCache.Acquire(r, ticket.AgentID)
conn, release, err := api.workspaceAgentCache.Acquire(ticket.AgentID)
if err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway,

View File

@ -32,7 +32,7 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
type Dialer func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
@ -78,7 +78,7 @@ type Cache struct {
// The returned function is used to release a lock on the connection. Once zero
// locks exist on a connection, the inactive timeout will begin to tick down.
// After the time expires, the connection will be cleared from the cache.
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) {
rawConn, found := c.connMap.Load(id.String())
// If the connection isn't found, establish a new one!
if !found {
@ -95,7 +95,7 @@ func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
}
c.closeGroup.Add(1)
c.closeMutex.Unlock()
agentConn, err := c.dialer(r, id)
agentConn, err := c.dialer(id)
if err != nil {
c.closeGroup.Done()
return nil, xerrors.Errorf("dial: %w", err)

View File

@ -40,33 +40,33 @@ func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, 0)
defer func() {
_ = cache.Close()
}()
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
conn1, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
conn2, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
require.True(t, conn1 == conn2)
})
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
called.Add(1)
return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
conn, release, err = cache.Acquire(uuid.Nil)
require.NoError(t, err)
release()
<-conn.Closed()
@ -74,13 +74,13 @@ func TestCache(t *testing.T) {
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
_ = cache.Close()
}()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err)
time.Sleep(time.Millisecond)
release()
@ -107,7 +107,7 @@ func TestCache(t *testing.T) {
}()
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx)
conn, release, err := cache.Acquire(req, uuid.Nil)
conn, release, err := cache.Acquire(uuid.Nil)
if !assert.NoError(t, err) {
return
}