mirror of
https://github.com/coder/coder.git
synced 2025-07-21 01:28:49 +00:00
fix: Improve use of context in websocket.NetConn
code paths (#6198)
This commit is contained in:
committed by
GitHub
parent
6fb8aff6d0
commit
5df7872661
@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
||||
})
|
||||
return
|
||||
}
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.Heartbeat(ctx, conn)
|
||||
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
|
||||
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
||||
return
|
||||
|
@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
|
||||
// Ping once every 30 seconds to ensure that the websocket is alive. If we
|
||||
// don't get a response within 30s we kill the websocket and reconnect.
|
||||
// See: https://github.com/coder/coder/pull/5824
|
||||
@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
return wsNetConn, nil
|
||||
}
|
||||
|
||||
type PostAppHealthsRequest struct {
|
||||
@ -529,3 +531,44 @@ type closeFunc func() error
|
||||
func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
||||
return nil, nil, ReadBodyAsError(res)
|
||||
}
|
||||
logs := make(chan ProvisionerJobLog)
|
||||
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
|
||||
closed := make(chan struct{})
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
decoder := json.NewDecoder(wsNetConn)
|
||||
go func() {
|
||||
defer close(closed)
|
||||
defer close(logs)
|
||||
@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
||||
}
|
||||
}()
|
||||
return logs, closeFunc(func() error {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
_ = wsNetConn.Close()
|
||||
<-closed
|
||||
return nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
|
||||
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
|
||||
// implementation. The context is during dial, not during the lifetime of the
|
||||
// client. Client should be closed after use.
|
||||
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
|
||||
if err != nil {
|
||||
@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
|
||||
// Use background context because caller should close the client.
|
||||
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
|
||||
session, err := yamux.Client(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||
_ = wsNetConn.Close()
|
||||
return nil, xerrors.Errorf("multiplex client: %w", err)
|
||||
}
|
||||
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
// @typescript-ignore wsNetConn
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
|
||||
}
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
||||
return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
|
||||
}
|
||||
|
||||
// WorkspaceAgentListeningPorts returns a list of ports that are currently being
|
||||
|
@ -1,11 +1,13 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Success 101
|
||||
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
|
||||
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
tags := map[string]string{}
|
||||
if r.URL.Query().Has("tag") {
|
||||
for _, tag := range r.URL.Query()["tag"] {
|
||||
parts := strings.SplitN(tag, "=", 2)
|
||||
if len(parts) < 2 {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
|
||||
})
|
||||
return
|
||||
@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
}
|
||||
if !r.URL.Query().Has("provisioner") {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "The provisioner query parameter must be specified.",
|
||||
})
|
||||
return
|
||||
@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
case string(codersdk.ProvisionerTypeTerraform):
|
||||
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
|
||||
default:
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
|
||||
})
|
||||
return
|
||||
@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
|
||||
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "You aren't allowed to create provisioner daemons for the organization.",
|
||||
})
|
||||
return
|
||||
@ -155,7 +159,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
name := namesgenerator.GetRandomName(1)
|
||||
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
|
||||
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
|
||||
ID: uuid.New(),
|
||||
CreatedAt: database.Now(),
|
||||
Name: name,
|
||||
@ -163,7 +167,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
Tags: tags,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error writing provisioner daemon.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
rawTags, err := json.Marshal(daemon.Tags)
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error marshaling daemon tags.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Internal error accepting websocket connection.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
// the same connection.
|
||||
config := yamux.DefaultConfig()
|
||||
config.LogOutput = io.Discard
|
||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
||||
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
session, err := yamux.Server(wsNetConn, config)
|
||||
if err != nil {
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
|
||||
return
|
||||
@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
|
||||
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
err = server.Serve(r.Context(), session)
|
||||
err = server.Serve(ctx, session)
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
|
||||
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||
return
|
||||
}
|
||||
@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||
// is called if a read or write error is encountered.
|
||||
type wsNetConn struct {
|
||||
cancel context.CancelFunc
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *wsNetConn) Close() error {
|
||||
defer c.cancel()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||
// is tied to the parent context and the lifetime of the conn. Any error
|
||||
// during read or write will cancel the context, but not close the
|
||||
// conn. Close should be called to release context resources.
|
||||
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
nc := websocket.NetConn(ctx, conn, msgType)
|
||||
return ctx, &wsNetConn{
|
||||
cancel: cancel,
|
||||
Conn: nc,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user