fix: Improve use of context in websocket.NetConn code paths (#6198)

This commit is contained in:
Mathias Fredriksson
2023-02-14 16:42:55 +02:00
committed by GitHub
parent 6fb8aff6d0
commit 5df7872661
5 changed files with 162 additions and 19 deletions

View File

@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
}) })
return return
} }
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
go httpapi.Heartbeat(ctx, conn) go httpapi.Heartbeat(ctx, conn)
defer conn.Close(websocket.StatusNormalClosure, "") defer conn.Close(websocket.StatusNormalClosure, "")
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID) err = (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn, uuid.New(), workspaceAgent.ID)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error()) _ = conn.Close(websocket.StatusInternalError, err.Error())
return return

View File

@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
return nil, codersdk.ReadBodyAsError(res) return nil, codersdk.ReadBodyAsError(res)
} }
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
// Ping once every 30 seconds to ensure that the websocket is alive. If we // Ping once every 30 seconds to ensure that the websocket is alive. If we
// don't get a response within 30s we kill the websocket and reconnect. // don't get a response within 30s we kill the websocket and reconnect.
// See: https://github.com/coder/coder/pull/5824 // See: https://github.com/coder/coder/pull/5824
@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
} }
}() }()
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil return wsNetConn, nil
} }
type PostAppHealthsRequest struct { type PostAppHealthsRequest struct {
@ -529,3 +531,44 @@ type closeFunc func() error
func (c closeFunc) Close() error { func (c closeFunc) Close() error {
return c() return c()
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
return nil, nil, ReadBodyAsError(res) return nil, nil, ReadBodyAsError(res)
} }
logs := make(chan ProvisionerJobLog) logs := make(chan ProvisionerJobLog)
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
closed := make(chan struct{}) closed := make(chan struct{})
ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageText)
decoder := json.NewDecoder(wsNetConn)
go func() { go func() {
defer close(closed) defer close(closed)
defer close(logs) defer close(logs)
@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
} }
}() }()
return logs, closeFunc(func() error { return logs, closeFunc(func() error {
_ = conn.Close(websocket.StatusNormalClosure, "") _ = wsNetConn.Close()
<-closed <-closed
return nil return nil
}), nil }), nil
} }
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation. // ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
// implementation. The context is during dial, not during the lifetime of the
// client. Client should be closed after use.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) { func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization)) serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
if err != nil { if err != nil {
@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
config := yamux.DefaultConfig() config := yamux.DefaultConfig()
config.LogOutput = io.Discard config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) // Use background context because caller should close the client.
_, wsNetConn := websocketNetConn(context.Background(), conn, websocket.MessageBinary)
session, err := yamux.Client(wsNetConn, config)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
_ = wsNetConn.Close()
return nil, xerrors.Errorf("multiplex client: %w", err) return nil, xerrors.Errorf("multiplex client: %w", err)
} }
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil return proto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)), nil
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
// @typescript-ignore wsNetConn
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}

View File

@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
} }
return nil, ReadBodyAsError(res) return nil, ReadBodyAsError(res)
} }
return websocket.NetConn(ctx, conn, websocket.MessageBinary), nil return websocket.NetConn(context.Background(), conn, websocket.MessageBinary), nil
} }
// WorkspaceAgentListeningPorts returns a list of ports that are currently being // WorkspaceAgentListeningPorts returns a list of ports that are currently being

View File

@ -1,11 +1,13 @@
package coderd package coderd
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strings" "strings"
@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
// @Success 101 // @Success 101
// @Router /organizations/{organization}/provisionerdaemons/serve [get] // @Router /organizations/{organization}/provisionerdaemons/serve [get]
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) { func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
tags := map[string]string{} tags := map[string]string{}
if r.URL.Query().Has("tag") { if r.URL.Query().Has("tag") {
for _, tag := range r.URL.Query()["tag"] { for _, tag := range r.URL.Query()["tag"] {
parts := strings.SplitN(tag, "=", 2) parts := strings.SplitN(tag, "=", 2)
if len(parts) < 2 { if len(parts) < 2 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag), Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
}) })
return return
@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
} }
} }
if !r.URL.Query().Has("provisioner") { if !r.URL.Query().Has("provisioner") {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "The provisioner query parameter must be specified.", Message: "The provisioner query parameter must be specified.",
}) })
return return
@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
case string(codersdk.ProvisionerTypeTerraform): case string(codersdk.ProvisionerTypeTerraform):
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{} provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
default: default:
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner), Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
}) })
return return
@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization { if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) { if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "You aren't allowed to create provisioner daemons for the organization.", Message: "You aren't allowed to create provisioner daemons for the organization.",
}) })
return return
@ -155,7 +159,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
} }
name := namesgenerator.GetRandomName(1) name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(), ID: uuid.New(),
CreatedAt: database.Now(), CreatedAt: database.Now(),
Name: name, Name: name,
@ -163,7 +167,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
Tags: tags, Tags: tags,
}) })
if err != nil { if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error writing provisioner daemon.", Message: "Internal error writing provisioner daemon.",
Detail: err.Error(), Detail: err.Error(),
}) })
@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
rawTags, err := json.Marshal(daemon.Tags) rawTags, err := json.Marshal(daemon.Tags)
if err != nil { if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling daemon tags.", Message: "Internal error marshaling daemon tags.",
Detail: err.Error(), Detail: err.Error(),
}) })
@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
CompressionMode: websocket.CompressionDisabled, CompressionMode: websocket.CompressionDisabled,
}) })
if err != nil { if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error accepting websocket connection.", Message: "Internal error accepting websocket connection.",
Detail: err.Error(), Detail: err.Error(),
}) })
@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
// the same connection. // the same connection.
config := yamux.DefaultConfig() config := yamux.DefaultConfig()
config.LogOutput = io.Discard config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
session, err := yamux.Server(wsNetConn, config)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
return return
@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
if xerrors.Is(err, io.EOF) { if xerrors.Is(err, io.EOF) {
return return
} }
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err)) api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
}, },
}) })
err = server.Serve(r.Context(), session) err = server.Serve(ctx, session)
if err != nil && !xerrors.Is(err, io.EOF) { if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err)) api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return return
} }
@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
} }
return result return result
} }
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
// is called if a read or write error is encountered.
type wsNetConn struct {
cancel context.CancelFunc
net.Conn
}
func (c *wsNetConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
if err != nil {
c.cancel()
}
return n, err
}
func (c *wsNetConn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
// websocketNetConn wraps websocket.NetConn and returns a context that
// is tied to the parent context and the lifetime of the conn. Any error
// during read or write will cancel the context, but not close the
// conn. Close should be called to release context resources.
func websocketNetConn(ctx context.Context, conn *websocket.Conn, msgType websocket.MessageType) (context.Context, net.Conn) {
ctx, cancel := context.WithCancel(ctx)
nc := websocket.NetConn(ctx, conn, msgType)
return ctx, &wsNetConn{
cancel: cancel,
Conn: nc,
}
}