mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
fix: use a background context when piping derp connections (#6750)
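Previously the piped DERP connections were tied to the originating HTTP request's context, which caused boatloads of connections to be reestablished every time a request finished. See https://github.com/coder/coder/issues/6746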
This commit is contained in:
@ -294,7 +294,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
go httpapi.Heartbeat(ctx, conn)
|
go httpapi.Heartbeat(ctx, conn)
|
||||||
|
|
||||||
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
|
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||||
return
|
return
|
||||||
@ -339,7 +339,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
|
agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error dialing workspace agent.",
|
Message: "Internal error dialing workspace agent.",
|
||||||
@ -414,10 +414,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
|
|||||||
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
|
httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||||
ctx := r.Context()
|
|
||||||
clientConn, serverConn := net.Pipe()
|
clientConn, serverConn := net.Pipe()
|
||||||
|
|
||||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||||
DERPMap: api.DERPMap,
|
DERPMap: api.DERPMap,
|
||||||
@ -428,6 +426,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
|
|||||||
_ = serverConn.Close()
|
_ = serverConn.Close()
|
||||||
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
return nil, xerrors.Errorf("create tailnet conn: %w", err)
|
||||||
}
|
}
|
||||||
|
ctx, cancel := context.WithCancel(api.ctx)
|
||||||
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
|
||||||
if !region.EmbeddedRelay {
|
if !region.EmbeddedRelay {
|
||||||
return nil
|
return nil
|
||||||
@ -437,7 +436,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
|
|||||||
defer left.Close()
|
defer left.Close()
|
||||||
defer right.Close()
|
defer right.Close()
|
||||||
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
|
||||||
api.DERPServer.Accept(ctx, right, brw, r.RemoteAddr)
|
api.DERPServer.Accept(ctx, right, brw, "internal")
|
||||||
}()
|
}()
|
||||||
return left
|
return left
|
||||||
})
|
})
|
||||||
@ -453,6 +452,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
|
|||||||
agentConn := &codersdk.WorkspaceAgentConn{
|
agentConn := &codersdk.WorkspaceAgentConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
CloseFunc: func() {
|
CloseFunc: func() {
|
||||||
|
cancel()
|
||||||
_ = clientConn.Close()
|
_ = clientConn.Close()
|
||||||
_ = serverConn.Close()
|
_ = serverConn.Close()
|
||||||
},
|
},
|
||||||
@ -460,7 +460,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
|
|||||||
go func() {
|
go func() {
|
||||||
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
|
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
|
api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
|
||||||
_ = agentConn.Close()
|
_ = agentConn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -639,7 +639,7 @@ func (api *API) proxyWorkspaceApplication(rw http.ResponseWriter, r *http.Reques
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, release, err := api.workspaceAgentCache.Acquire(r, ticket.AgentID)
|
conn, release, err := api.workspaceAgentCache.Acquire(ticket.AgentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||||
Status: http.StatusBadGateway,
|
Status: http.StatusBadGateway,
|
||||||
|
@ -32,7 +32,7 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dialer creates a new agent connection by ID.
|
// Dialer creates a new agent connection by ID.
|
||||||
type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
|
type Dialer func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
|
||||||
|
|
||||||
// Conn wraps an agent connection with a reusable HTTP transport.
|
// Conn wraps an agent connection with a reusable HTTP transport.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -78,7 +78,7 @@ type Cache struct {
|
|||||||
// The returned function is used to release a lock on the connection. Once zero
|
// The returned function is used to release a lock on the connection. Once zero
|
||||||
// locks exist on a connection, the inactive timeout will begin to tick down.
|
// locks exist on a connection, the inactive timeout will begin to tick down.
|
||||||
// After the time expires, the connection will be cleared from the cache.
|
// After the time expires, the connection will be cleared from the cache.
|
||||||
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
|
func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) {
|
||||||
rawConn, found := c.connMap.Load(id.String())
|
rawConn, found := c.connMap.Load(id.String())
|
||||||
// If the connection isn't found, establish a new one!
|
// If the connection isn't found, establish a new one!
|
||||||
if !found {
|
if !found {
|
||||||
@ -95,7 +95,7 @@ func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
|
|||||||
}
|
}
|
||||||
c.closeGroup.Add(1)
|
c.closeGroup.Add(1)
|
||||||
c.closeMutex.Unlock()
|
c.closeMutex.Unlock()
|
||||||
agentConn, err := c.dialer(r, id)
|
agentConn, err := c.dialer(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.closeGroup.Done()
|
c.closeGroup.Done()
|
||||||
return nil, xerrors.Errorf("dial: %w", err)
|
return nil, xerrors.Errorf("dial: %w", err)
|
||||||
|
@ -40,33 +40,33 @@ func TestCache(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("Same", func(t *testing.T) {
|
t.Run("Same", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||||
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
||||||
}, 0)
|
}, 0)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = cache.Close()
|
_ = cache.Close()
|
||||||
}()
|
}()
|
||||||
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
|
conn1, _, err := cache.Acquire(uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
|
conn2, _, err := cache.Acquire(uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, conn1 == conn2)
|
require.True(t, conn1 == conn2)
|
||||||
})
|
})
|
||||||
t.Run("Expire", func(t *testing.T) {
|
t.Run("Expire", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
called := atomic.NewInt32(0)
|
called := atomic.NewInt32(0)
|
||||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||||
called.Add(1)
|
called.Add(1)
|
||||||
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
||||||
}, time.Microsecond)
|
}, time.Microsecond)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = cache.Close()
|
_ = cache.Close()
|
||||||
}()
|
}()
|
||||||
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
|
conn, release, err := cache.Acquire(uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
release()
|
release()
|
||||||
<-conn.Closed()
|
<-conn.Closed()
|
||||||
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
|
conn, release, err = cache.Acquire(uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
release()
|
release()
|
||||||
<-conn.Closed()
|
<-conn.Closed()
|
||||||
@ -74,13 +74,13 @@ func TestCache(t *testing.T) {
|
|||||||
})
|
})
|
||||||
t.Run("NoExpireWhenLocked", func(t *testing.T) {
|
t.Run("NoExpireWhenLocked", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||||
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
||||||
}, time.Microsecond)
|
}, time.Microsecond)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = cache.Close()
|
_ = cache.Close()
|
||||||
}()
|
}()
|
||||||
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil)
|
conn, release, err := cache.Acquire(uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Millisecond)
|
time.Sleep(time.Millisecond)
|
||||||
release()
|
release()
|
||||||
@ -107,7 +107,7 @@ func TestCache(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
go server.Serve(random)
|
go server.Serve(random)
|
||||||
|
|
||||||
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
|
||||||
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
return setupAgent(t, agentsdk.Metadata{}, 0), nil
|
||||||
}, time.Microsecond)
|
}, time.Microsecond)
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
conn, release, err := cache.Acquire(req, uuid.Nil)
|
conn, release, err := cache.Acquire(uuid.Nil)
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user