fix: use a background context when piping derp connections (#6750)

This was causing boatloads of connects to reestablish every time...

See https://github.com/coder/coder/issues/6746
This commit is contained in:
Kyle Carberry
2023-03-23 09:54:07 -05:00
committed by GitHub
parent 7949db8e03
commit ed9a3b9251
4 changed files with 21 additions and 21 deletions

View File

@ -294,7 +294,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
go httpapi.Heartbeat(ctx, conn) go httpapi.Heartbeat(ctx, conn)
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID) agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
return return
@ -339,7 +339,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return return
} }
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID) agentConn, release, err := api.workspaceAgentCache.Acquire(workspaceAgent.ID)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error dialing workspace agent.", Message: "Internal error dialing workspace agent.",
@ -414,10 +414,8 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
httpapi.Write(ctx, rw, http.StatusOK, portsResponse) httpapi.Write(ctx, rw, http.StatusOK, portsResponse)
} }
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { func (api *API) dialWorkspaceAgentTailnet(agentID uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
ctx := r.Context()
clientConn, serverConn := net.Pipe() clientConn, serverConn := net.Pipe()
conn, err := tailnet.NewConn(&tailnet.Options{ conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: api.DERPMap, DERPMap: api.DERPMap,
@ -428,6 +426,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
_ = serverConn.Close() _ = serverConn.Close()
return nil, xerrors.Errorf("create tailnet conn: %w", err) return nil, xerrors.Errorf("create tailnet conn: %w", err)
} }
ctx, cancel := context.WithCancel(api.ctx)
conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn {
if !region.EmbeddedRelay { if !region.EmbeddedRelay {
return nil return nil
@ -437,7 +436,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
defer left.Close() defer left.Close()
defer right.Close() defer right.Close()
brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right))
api.DERPServer.Accept(ctx, right, brw, r.RemoteAddr) api.DERPServer.Accept(ctx, right, brw, "internal")
}() }()
return left return left
}) })
@ -453,6 +452,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
agentConn := &codersdk.WorkspaceAgentConn{ agentConn := &codersdk.WorkspaceAgentConn{
Conn: conn, Conn: conn,
CloseFunc: func() { CloseFunc: func() {
cancel()
_ = clientConn.Close() _ = clientConn.Close()
_ = serverConn.Close() _ = serverConn.Close()
}, },
@ -460,7 +460,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
go func() { go func() {
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID) err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil { if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err)) api.Logger.Warn(ctx, "tailnet coordinator client error", slog.Error(err))
_ = agentConn.Close() _ = agentConn.Close()
} }
}() }()

View File

@ -639,7 +639,7 @@ func (api *API) proxyWorkspaceApplication(rw http.ResponseWriter, r *http.Reques
}) })
} }
conn, release, err := api.workspaceAgentCache.Acquire(r, ticket.AgentID) conn, release, err := api.workspaceAgentCache.Acquire(ticket.AgentID)
if err != nil { if err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{ site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadGateway, Status: http.StatusBadGateway,

View File

@ -32,7 +32,7 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
} }
// Dialer creates a new agent connection by ID. // Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) type Dialer func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error)
// Conn wraps an agent connection with a reusable HTTP transport. // Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct { type Conn struct {
@ -78,7 +78,7 @@ type Cache struct {
// The returned function is used to release a lock on the connection. Once zero // The returned function is used to release a lock on the connection. Once zero
// locks exist on a connection, the inactive timeout will begin to tick down. // locks exist on a connection, the inactive timeout will begin to tick down.
// After the time expires, the connection will be cleared from the cache. // After the time expires, the connection will be cleared from the cache.
func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) { func (c *Cache) Acquire(id uuid.UUID) (*Conn, func(), error) {
rawConn, found := c.connMap.Load(id.String()) rawConn, found := c.connMap.Load(id.String())
// If the connection isn't found, establish a new one! // If the connection isn't found, establish a new one!
if !found { if !found {
@ -95,7 +95,7 @@ func (c *Cache) Acquire(r *http.Request, id uuid.UUID) (*Conn, func(), error) {
} }
c.closeGroup.Add(1) c.closeGroup.Add(1)
c.closeMutex.Unlock() c.closeMutex.Unlock()
agentConn, err := c.dialer(r, id) agentConn, err := c.dialer(id)
if err != nil { if err != nil {
c.closeGroup.Done() c.closeGroup.Done()
return nil, xerrors.Errorf("dial: %w", err) return nil, xerrors.Errorf("dial: %w", err)

View File

@ -40,33 +40,33 @@ func TestCache(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("Same", func(t *testing.T) { t.Run("Same", func(t *testing.T) {
t.Parallel() t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, 0) }, 0)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
}() }()
conn1, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) conn1, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
conn2, _, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) conn2, _, err := cache.Acquire(uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
require.True(t, conn1 == conn2) require.True(t, conn1 == conn2)
}) })
t.Run("Expire", func(t *testing.T) { t.Run("Expire", func(t *testing.T) {
t.Parallel() t.Parallel()
called := atomic.NewInt32(0) called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
called.Add(1) called.Add(1)
return setupAgent(t, agentsdk.Metadata{}, 0), nil return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
}() }()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
release() release()
<-conn.Closed() <-conn.Closed()
conn, release, err = cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) conn, release, err = cache.Acquire(uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
release() release()
<-conn.Closed() <-conn.Closed()
@ -74,13 +74,13 @@ func TestCache(t *testing.T) {
}) })
t.Run("NoExpireWhenLocked", func(t *testing.T) { t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel() t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
_ = cache.Close() _ = cache.Close()
}() }()
conn, release, err := cache.Acquire(httptest.NewRequest(http.MethodGet, "/", nil), uuid.Nil) conn, release, err := cache.Acquire(uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
release() release()
@ -107,7 +107,7 @@ func TestCache(t *testing.T) {
}() }()
go server.Serve(random) go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) { cache := wsconncache.New(func(id uuid.UUID) (*codersdk.WorkspaceAgentConn, error) {
return setupAgent(t, agentsdk.Metadata{}, 0), nil return setupAgent(t, agentsdk.Metadata{}, 0), nil
}, time.Microsecond) }, time.Microsecond)
defer func() { defer func() {
@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
defer cancel() defer cancel()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
conn, release, err := cache.Acquire(req, uuid.Nil) conn, release, err := cache.Acquire(uuid.Nil)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }