mirror of
https://github.com/coder/coder.git
synced 2025-07-21 01:28:49 +00:00
fix: Improve use of context in websocket.NetConn
code paths (#6198)
This commit is contained in:
committed by
GitHub
parent
6fb8aff6d0
commit
5df7872661
@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||||
|
defer wsNetConn.Close()
|
||||||
|
|
||||||
go httpapi.Heartbeat(ctx, conn)
|
go httpapi.Heartbeat(ctx, conn)
|
||||||
|
|
||||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||||
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
|
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
_ = conn.Close(websocket.StatusInternalError, err.Error())
|
||||||
return
|
return
|
||||||
|
@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
|||||||
return nil, codersdk.ReadBodyAsError(res)
|
return nil, codersdk.ReadBodyAsError(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||||
|
|
||||||
// Ping once every 30 seconds to ensure that the websocket is alive. If we
|
// Ping once every 30 seconds to ensure that the websocket is alive. If we
|
||||||
// don't get a response within 30s we kill the websocket and reconnect.
|
// don't get a response within 30s we kill the websocket and reconnect.
|
||||||
// See: https://github.com/coder/coder/pull/5824
|
// See: https://github.com/coder/coder/pull/5824
|
||||||
@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
return wsNetConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PostAppHealthsRequest struct {
|
type PostAppHealthsRequest struct {
|
||||||
@ -529,3 +531,44 @@ type closeFunc func() error
|
|||||||
func (c closeFunc) Close() error {
|
func (c closeFunc) Close() error {
|
||||||
return c()
|
return c()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||||
|
// is called if a read or write error is encountered.
|
||||||
|
type wsNetConn struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Close() error {
|
||||||
|
defer c.cancel()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||||
|
// is tied to the parent context and the lifetime of the conn. Any error
|
||||||
|
// during read or write will cancel the context, but not close the
|
||||||
|
// conn. Close should be called to release context resources.
|
||||||
|
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
nc := websocket.NetConn(ctx, conn, msgType)
|
||||||
|
return ctx, &wsNetConn{
|
||||||
|
cancel: cancel,
|
||||||
|
Conn: nc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
|||||||
return nil, nil, ReadBodyAsError(res)
|
return nil, nil, ReadBodyAsError(res)
|
||||||
}
|
}
|
||||||
logs := make(chan ProvisionerJobLog)
|
logs := make(chan ProvisionerJobLog)
|
||||||
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
|
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
|
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||||
|
decoder := json.NewDecoder(wsNetConn)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(closed)
|
defer close(closed)
|
||||||
defer close(logs)
|
defer close(logs)
|
||||||
@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return logs, closeFunc(func() error {
|
return logs, closeFunc(func() error {
|
||||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
_ = wsNetConn.Close()
|
||||||
<-closed
|
<-closed
|
||||||
return nil
|
return nil
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
|
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
|
||||||
|
// implementation. The context is during dial, not during the lifetime of the
|
||||||
|
// client. Client should be closed after use.
|
||||||
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
|
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
|
||||||
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
|
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
|
|||||||
|
|
||||||
config := yamux.DefaultConfig()
|
config := yamux.DefaultConfig()
|
||||||
config.LogOutput = io.Discard
|
config.LogOutput = io.Discard
|
||||||
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
|
// Use background context because caller should close the client.
|
||||||
|
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
|
||||||
|
session, err := yamux.Client(wsNetConn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = conn.Close(websocket.StatusGoingAway, "")
|
||||||
|
_ = wsNetConn.Close()
|
||||||
return nil, xerrors.Errorf("multiplex client: %w", err)
|
return nil, xerrors.Errorf("multiplex client: %w", err)
|
||||||
}
|
}
|
||||||
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
|
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||||
|
// is called if a read or write error is encountered.
|
||||||
|
// @typescript-ignore wsNetConn
|
||||||
|
type wsNetConn struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Close() error {
|
||||||
|
defer c.cancel()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||||
|
// is tied to the parent context and the lifetime of the conn. Any error
|
||||||
|
// during read or write will cancel the context, but not close the
|
||||||
|
// conn. Close should be called to release context resources.
|
||||||
|
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
nc := websocket.NetConn(ctx, conn, msgType)
|
||||||
|
return ctx, &wsNetConn{
|
||||||
|
cancel: cancel,
|
||||||
|
Conn: nc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
|
|||||||
}
|
}
|
||||||
return nil, ReadBodyAsError(res)
|
return nil, ReadBodyAsError(res)
|
||||||
}
|
}
|
||||||
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
|
return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WorkspaceAgentListeningPorts returns a list of ports that are currently being
|
// WorkspaceAgentListeningPorts returns a list of ports that are currently being
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package coderd
|
package coderd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
|||||||
// @Success 101
|
// @Success 101
|
||||||
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
|
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
|
||||||
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
tags := map[string]string{}
|
tags := map[string]string{}
|
||||||
if r.URL.Query().Has("tag") {
|
if r.URL.Query().Has("tag") {
|
||||||
for _, tag := range r.URL.Query()["tag"] {
|
for _, tag := range r.URL.Query()["tag"] {
|
||||||
parts := strings.SplitN(tag, "=", 2)
|
parts := strings.SplitN(tag, "=", 2)
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
|
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !r.URL.Query().Has("provisioner") {
|
if !r.URL.Query().Has("provisioner") {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: "The provisioner query parameter must be specified.",
|
Message: "The provisioner query parameter must be specified.",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
case string(codersdk.ProvisionerTypeTerraform):
|
case string(codersdk.ProvisionerTypeTerraform):
|
||||||
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
|
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
|
||||||
default:
|
default:
|
||||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
|
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
|
|
||||||
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
|
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
|
||||||
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
|
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||||
Message: "You aren't allowed to create provisioner daemons for the organization.",
|
Message: "You aren't allowed to create provisioner daemons for the organization.",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@ -155,7 +159,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
name := namesgenerator.GetRandomName(1)
|
name := namesgenerator.GetRandomName(1)
|
||||||
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
|
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
CreatedAt: database.Now(),
|
CreatedAt: database.Now(),
|
||||||
Name: name,
|
Name: name,
|
||||||
@ -163,7 +167,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
Tags: tags,
|
Tags: tags,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error writing provisioner daemon.",
|
Message: "Internal error writing provisioner daemon.",
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
|
|
||||||
rawTags, err := json.Marshal(daemon.Tags)
|
rawTags, err := json.Marshal(daemon.Tags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||||
Message: "Internal error marshaling daemon tags.",
|
Message: "Internal error marshaling daemon tags.",
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
CompressionMode: websocket.CompressionDisabled,
|
CompressionMode: websocket.CompressionDisabled,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
|
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||||
Message: "Internal error accepting websocket connection.",
|
Message: "Internal error accepting websocket connection.",
|
||||||
Detail: err.Error(),
|
Detail: err.Error(),
|
||||||
})
|
})
|
||||||
@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
// the same connection.
|
// the same connection.
|
||||||
config := yamux.DefaultConfig()
|
config := yamux.DefaultConfig()
|
||||||
config.LogOutput = io.Discard
|
config.LogOutput = io.Discard
|
||||||
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
|
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||||
|
defer wsNetConn.Close()
|
||||||
|
session, err := yamux.Server(wsNetConn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
|
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
|
||||||
return
|
return
|
||||||
@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
|||||||
if xerrors.Is(err, io.EOF) {
|
if xerrors.Is(err, io.EOF) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
|
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
err = server.Serve(r.Context(), session)
|
err = server.Serve(ctx, session)
|
||||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||||
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
|
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
|
||||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
|
||||||
|
// is called if a read or write error is encountered.
|
||||||
|
type wsNetConn struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Read(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Write(b []byte) (n int, err error) {
|
||||||
|
n, err = c.Conn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *wsNetConn) Close() error {
|
||||||
|
defer c.cancel()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// websocketNetConn wraps websocket.NetConn and returns a context that
|
||||||
|
// is tied to the parent context and the lifetime of the conn. Any error
|
||||||
|
// during read or write will cancel the context, but not close the
|
||||||
|
// conn. Close should be called to release context resources.
|
||||||
|
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
nc := websocket.NetConn(ctx, conn, msgType)
|
||||||
|
return ctx, &wsNetConn{
|
||||||
|
cancel: cancel,
|
||||||
|
Conn: nc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user