mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
fix: Avoid use of r.Context() after r.Hijack() (#1978)
This commit is contained in:
committed by
GitHub
parent
61aacff444
commit
b4f5920df5
@ -1,6 +1,7 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
}
|
||||
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
|
||||
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
|
||||
ChannelID: workspaceAgent.ID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
|
||||
return
|
||||
@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
disconnectedAt := workspaceAgent.DisconnectedAt
|
||||
updateConnectionTimes := func() error {
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
|
||||
ID: workspaceAgent.ID,
|
||||
FirstConnectedAt: firstConnectedAt,
|
||||
LastConnectedAt: lastConnectedAt,
|
||||
@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
|
||||
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
|
||||
|
||||
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
|
||||
defer ticker.Stop()
|
||||
@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
netConn := websocket.NetConn(r.Context(), wsConn, websocket.MessageBinary)
|
||||
api.Logger.Debug(r.Context(), "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
select {
|
||||
case <-api.TURNServer.Accept(netConn, remoteAddress, localAddress).Closed():
|
||||
case <-r.Context().Done():
|
||||
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
api.Logger.Debug(r.Context(), "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
|
||||
}
|
||||
|
||||
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
|
||||
@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "ended")
|
||||
}()
|
||||
// Accept text connections, because it's more developer friendly.
|
||||
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
agentConn, err := api.dialWorkspaceAgent(ctx, r, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.Copy(ptNetConn, wsNetConn)
|
||||
}
|
||||
|
||||
// dialWorkspaceAgent connects to a workspace agent by ID.
|
||||
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
|
||||
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
|
||||
// r.Context() for cancellation if it's use is safe or r.Hijack() has
|
||||
// not been performed.
|
||||
func (api *API) dialWorkspaceAgent(ctx context.Context, r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
go func() {
|
||||
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
|
||||
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
|
||||
ChannelID: agentID.String(),
|
||||
Logger: api.Logger.Named("peerbroker-proxy-dial"),
|
||||
Pubsub: api.Pubsub,
|
||||
@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
|
||||
}()
|
||||
|
||||
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := peerClient.NegotiateConnection(r.Context())
|
||||
stream, err := peerClient.NegotiateConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("negotiate: %w", err)
|
||||
}
|
||||
@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
|
||||
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
|
||||
clientPipe, serverPipe := net.Pipe()
|
||||
go func() {
|
||||
<-r.Context().Done()
|
||||
<-ctx.Done()
|
||||
_ = clientPipe.Close()
|
||||
_ = serverPipe.Close()
|
||||
}()
|
||||
@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
|
||||
|
||||
return workspaceAgent, nil
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user