fix: Improve use of context in websocket.NetConn code paths (#6198)

This commit is contained in:
Mathias Fredriksson
2023-02-14 16:42:55 +02:00
committed by GitHub
parent 6fb8aff6d0
commit 5df7872661
5 changed files with 162 additions and 19 deletions

View File

@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
})
return
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
go httpapi.Heartbeat(ctx, conn)
defer conn.Close(websocket.StatusNormalClosure, "")
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return

View File

@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
return nil, codersdk.ReadBodyAsError(res)
}
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
// Ping once every 30 seconds to ensure that the websocket is alive. If we
// don't get a response within 30s we kill the websocket and reconnect.
// See: https://github.com/coder/coder/pull/5824
@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
}
}()
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
return wsNetConn, nil
}
type PostAppHealthsRequest struct {
@ -529,3 +531,44 @@ type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
return nil, nil, ReadBodyAsError(res)
}
logs := make(chan ProvisionerJobLog)
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() {
defer close(closed)
defer close(logs)
@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
}
}()
return logs, closeFunc(func() error {
_ = conn.Close(websocket.StatusNormalClosure, "")
_ = wsNetConn.Close()
<-closed
return nil
}), nil
}
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
// implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
if err != nil {
@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
// Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
_ = wsNetConn.Close()
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
}
return nil, ReadBodyAsError(res)
}
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil
return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
}
// WorkspaceAgentListeningPorts returns a list of ports that are currently being

View File

@ -1,11 +1,13 @@
package coderd
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
// @Success 101
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tags := map[string]string{}
if r.URL.Query().Has("tag") {
for _, tag := range r.URL.Query()["tag"] {
parts := strings.SplitN(tag, "=", 2)
if len(parts) < 2 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
})
return
@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
}
if !r.URL.Query().Has("provisioner") {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "The provisioner query parameter must be specified.",
})
return
@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
case string(codersdk.ProvisionerTypeTerraform):
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
default:
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
})
return
@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "You aren't allowed to create provisioner daemons for the organization.",
})
return
@ -155,7 +159,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
@ -163,7 +167,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
Tags: tags,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error writing provisioner daemon.",
Detail: err.Error(),
})
@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
rawTags, err := json.Marshal(daemon.Tags)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling daemon tags.",
Detail: err.Error(),
})
@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error accepting websocket connection.",
Detail: err.Error(),
})
@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
// the same connection.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
return
@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
err = server.Serve(r.Context(), session)
err = server.Serve(ctx, session)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
}
return result
}
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}