chore: apply the 4mb max limit on drpc protocol message size (#17771)

Respect the 4mb max limit on proto messages
This commit is contained in:
Steven Masley
2025-05-13 11:24:51 -05:00
committed by GitHub
parent a1c03b6c5f
commit 64807e1d61
10 changed files with 121 additions and 10 deletions

View File

@@ -60,6 +60,7 @@ func NewClient(t testing.TB,
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return

View File

@@ -30,6 +30,7 @@ import (
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/tailnet"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
@@ -209,6 +210,7 @@ func (a *API) Server(ctx context.Context) (*drpcserver.Server, error) {
return drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return

View File

@@ -38,6 +38,7 @@ import (
"tailscale.com/util/singleflight"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/quartz"
"github.com/coder/serpent"
@@ -84,7 +85,6 @@ import (
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
@@ -1803,6 +1803,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
}
server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return

View File

@@ -9,6 +9,7 @@ import (
"github.com/valyala/fasthttp/fasthttputil"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmanager"
"github.com/coder/coder/v2/coderd/tracing"
)
@@ -19,6 +20,17 @@ const (
MaxMessageSize = 4 << 20
)
func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options {
if options == nil {
options = &drpcmanager.Options{}
}
if options.Reader.MaximumBufferSize == 0 {
options.Reader.MaximumBufferSize = MaxMessageSize
}
return *options
}
// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
func MultiplexedConn(session *yamux.Session) drpc.Conn {
return &multiplexedDRPC{session}
@@ -43,7 +55,9 @@ func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encod
if err != nil {
return err
}
dConn := drpcconn.New(conn)
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})
defer func() {
_ = dConn.Close()
}()
@@ -55,7 +69,9 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En
if err != nil {
return nil, err
}
dConn := drpcconn.New(conn)
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})
stream, err := dConn.NewStream(ctx, rpc, enc)
if err == nil {
go func() {
@@ -97,7 +113,9 @@ func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inM
return err
}
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})}
defer func() {
_ = dConn.Close()
_ = conn.Close()
@@ -110,7 +128,9 @@ func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding)
if err != nil {
return nil, err
}
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: DefaultDRPCOptions(nil),
})}
stream, err := dConn.NewStream(ctx, rpc, enc)
if err != nil {
_ = dConn.Close()

View File

@@ -31,6 +31,7 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/websocket"
@@ -370,6 +371,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
return
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return

View File

@@ -27,6 +27,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisioner/echo"
agpl "github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
@@ -188,8 +189,10 @@ func (r *remoteConnector) handleConn(conn net.Conn) {
logger.Info(r.ctx, "provisioner connected")
closeConn = false // we're passing the conn over the channel
w.respCh <- agpl.ConnectResponse{
Job: w.job,
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
Job: w.job,
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(tlsConn, drpcconn.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})),
}
}

View File

@@ -178,6 +178,79 @@ func TestProvisionerd(t *testing.T) {
require.NoError(t, closer.Close())
})
// LargePayloads sends a 3mb tar file to the provisioner. The provisioner also
// returns large payload messages back. The limit should be 4mb, so all
// these messages should work.
t.Run("LargePayloads", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
largeSize = 3 * 1024 * 1024
completeChan = make(chan struct{})
completeOnce sync.Once
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"toolarge.txt": string(make([]byte, largeSize)),
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: noopUpdateJob,
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
s *provisionersdk.Session,
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
return &sdkproto.ParseComplete{
// 6mb readme
Readme: make([]byte, largeSize),
}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{
Resources: []*sdkproto.Resource{},
Plan: make([]byte, largeSize),
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
return &sdkproto.ApplyComplete{
State: make([]byte, largeSize),
}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
@@ -1115,7 +1188,9 @@ func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server pr
mux := drpcmux.New()
err := proto.DRPCRegisterProvisionerDaemon(mux, &server)
require.NoError(t, err)
srv := drpcserver.New(mux)
srv := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})
ctx, cancelFunc := context.WithCancel(context.Background())
closed := make(chan struct{})
go func() {

View File

@@ -15,6 +15,7 @@ import (
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/provisionersdk/proto"
@@ -81,7 +82,9 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
if err != nil {
return xerrors.Errorf("register provisioner: %w", err)
}
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
srv := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
})
if options.Listener != nil {
err = srv.Serve(ctx, options.Listener)

View File

@@ -94,7 +94,9 @@ func TestProvisionerSDK(t *testing.T) {
srvErr <- err
}()
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
api := proto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(client, drpcconn.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
}))
s, err := api.Session(ctx)
require.NoError(t, err)
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})

View File

@@ -17,6 +17,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/quartz"
)
@@ -92,6 +93,7 @@ func NewClientService(options ClientServiceOptions) (
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Manager: drpcsdk.DefaultDRPCOptions(nil),
Log: func(err error) {
if xerrors.Is(err, io.EOF) ||
xerrors.Is(err, context.Canceled) ||