mirror of
https://github.com/coder/coder.git
synced 2025-07-23 21:32:07 +00:00
chore: apply the 4mb max limit on drpc protocol message size (#17771)
Respect the 4mb max limit on proto messages
This commit is contained in:
@@ -60,6 +60,7 @@ func NewClient(t testing.TB,
|
||||
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
|
||||
require.NoError(t, err)
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
|
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
@@ -209,6 +210,7 @@ func (a *API) Server(ctx context.Context) (*drpcserver.Server, error) {
|
||||
|
||||
return drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
|
||||
drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
|
@@ -38,6 +38,7 @@ import (
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
@@ -84,7 +85,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/coderd/workspacestats"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
@@ -1803,6 +1803,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
|
||||
}
|
||||
server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux},
|
||||
drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
)
|
||||
@@ -19,6 +20,17 @@ const (
|
||||
MaxMessageSize = 4 << 20
|
||||
)
|
||||
|
||||
func DefaultDRPCOptions(options *drpcmanager.Options) drpcmanager.Options {
|
||||
if options == nil {
|
||||
options = &drpcmanager.Options{}
|
||||
}
|
||||
|
||||
if options.Reader.MaximumBufferSize == 0 {
|
||||
options.Reader.MaximumBufferSize = MaxMessageSize
|
||||
}
|
||||
return *options
|
||||
}
|
||||
|
||||
// MultiplexedConn returns a multiplexed dRPC connection from a yamux Session.
|
||||
func MultiplexedConn(session *yamux.Session) drpc.Conn {
|
||||
return &multiplexedDRPC{session}
|
||||
@@ -43,7 +55,9 @@ func (m *multiplexedDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encod
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dConn := drpcconn.New(conn)
|
||||
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
|
||||
Manager: DefaultDRPCOptions(nil),
|
||||
})
|
||||
defer func() {
|
||||
_ = dConn.Close()
|
||||
}()
|
||||
@@ -55,7 +69,9 @@ func (m *multiplexedDRPC) NewStream(ctx context.Context, rpc string, enc drpc.En
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dConn := drpcconn.New(conn)
|
||||
dConn := drpcconn.NewWithOptions(conn, drpcconn.Options{
|
||||
Manager: DefaultDRPCOptions(nil),
|
||||
})
|
||||
stream, err := dConn.NewStream(ctx, rpc, enc)
|
||||
if err == nil {
|
||||
go func() {
|
||||
@@ -97,7 +113,9 @@ func (m *memDRPC) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, inM
|
||||
return err
|
||||
}
|
||||
|
||||
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
|
||||
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
|
||||
Manager: DefaultDRPCOptions(nil),
|
||||
})}
|
||||
defer func() {
|
||||
_ = dConn.Close()
|
||||
_ = conn.Close()
|
||||
@@ -110,7 +128,9 @@ func (m *memDRPC) NewStream(ctx context.Context, rpc string, enc drpc.Encoding)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dConn := &tracing.DRPCConn{Conn: drpcconn.New(conn)}
|
||||
dConn := &tracing.DRPCConn{Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
|
||||
Manager: DefaultDRPCOptions(nil),
|
||||
})}
|
||||
stream, err := dConn.NewStream(ctx, rpc, enc)
|
||||
if err != nil {
|
||||
_ = dConn.Close()
|
||||
|
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
"github.com/coder/coder/v2/provisionersdk"
|
||||
"github.com/coder/websocket"
|
||||
@@ -370,6 +371,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
return
|
||||
|
@@ -27,6 +27,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
agpl "github.com/coder/coder/v2/provisionerd"
|
||||
"github.com/coder/coder/v2/provisionerd/proto"
|
||||
@@ -188,8 +189,10 @@ func (r *remoteConnector) handleConn(conn net.Conn) {
|
||||
logger.Info(r.ctx, "provisioner connected")
|
||||
closeConn = false // we're passing the conn over the channel
|
||||
w.respCh <- agpl.ConnectResponse{
|
||||
Job: w.job,
|
||||
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.New(tlsConn)),
|
||||
Job: w.job,
|
||||
Client: sdkproto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(tlsConn, drpcconn.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -178,6 +178,79 @@ func TestProvisionerd(t *testing.T) {
|
||||
require.NoError(t, closer.Close())
|
||||
})
|
||||
|
||||
// LargePayloads sends a 3mb tar file to the provisioner. The provisioner also
|
||||
// returns large payload messages back. The limit should be 4mb, so all
|
||||
// these messages should work.
|
||||
t.Run("LargePayloads", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(done)
|
||||
})
|
||||
var (
|
||||
largeSize = 3 * 1024 * 1024
|
||||
completeChan = make(chan struct{})
|
||||
completeOnce sync.Once
|
||||
acq = newAcquireOne(t, &proto.AcquiredJob{
|
||||
JobId: "test",
|
||||
Provisioner: "someprovisioner",
|
||||
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
|
||||
"toolarge.txt": string(make([]byte, largeSize)),
|
||||
}),
|
||||
Type: &proto.AcquiredJob_TemplateImport_{
|
||||
TemplateImport: &proto.AcquiredJob_TemplateImport{
|
||||
Metadata: &sdkproto.Metadata{},
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
||||
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
|
||||
acquireJobWithCancel: acq.acquireWithCancel,
|
||||
updateJob: noopUpdateJob,
|
||||
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
|
||||
completeOnce.Do(func() { close(completeChan) })
|
||||
return &proto.Empty{}, nil
|
||||
},
|
||||
}), nil
|
||||
}, provisionerd.LocalProvisioners{
|
||||
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
|
||||
parse: func(
|
||||
s *provisionersdk.Session,
|
||||
_ *sdkproto.ParseRequest,
|
||||
cancelOrComplete <-chan struct{},
|
||||
) *sdkproto.ParseComplete {
|
||||
return &sdkproto.ParseComplete{
|
||||
// 6mb readme
|
||||
Readme: make([]byte, largeSize),
|
||||
}
|
||||
},
|
||||
plan: func(
|
||||
_ *provisionersdk.Session,
|
||||
_ *sdkproto.PlanRequest,
|
||||
_ <-chan struct{},
|
||||
) *sdkproto.PlanComplete {
|
||||
return &sdkproto.PlanComplete{
|
||||
Resources: []*sdkproto.Resource{},
|
||||
Plan: make([]byte, largeSize),
|
||||
}
|
||||
},
|
||||
apply: func(
|
||||
_ *provisionersdk.Session,
|
||||
_ *sdkproto.ApplyRequest,
|
||||
_ <-chan struct{},
|
||||
) *sdkproto.ApplyComplete {
|
||||
return &sdkproto.ApplyComplete{
|
||||
State: make([]byte, largeSize),
|
||||
}
|
||||
},
|
||||
}),
|
||||
})
|
||||
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
|
||||
require.NoError(t, closer.Close())
|
||||
})
|
||||
|
||||
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
done := make(chan struct{})
|
||||
@@ -1115,7 +1188,9 @@ func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server pr
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterProvisionerDaemon(mux, &server)
|
||||
require.NoError(t, err)
|
||||
srv := drpcserver.New(mux)
|
||||
srv := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
})
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
|
@@ -15,6 +15,7 @@ import (
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
@@ -81,7 +82,9 @@ func Serve(ctx context.Context, server Server, options *ServeOptions) error {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("register provisioner: %w", err)
|
||||
}
|
||||
srv := drpcserver.New(&tracing.DRPCHandler{Handler: mux})
|
||||
srv := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
})
|
||||
|
||||
if options.Listener != nil {
|
||||
err = srv.Serve(ctx, options.Listener)
|
||||
|
@@ -94,7 +94,9 @@ func TestProvisionerSDK(t *testing.T) {
|
||||
srvErr <- err
|
||||
}()
|
||||
|
||||
api := proto.NewDRPCProvisionerClient(drpcconn.New(client))
|
||||
api := proto.NewDRPCProvisionerClient(drpcconn.NewWithOptions(client, drpcconn.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
}))
|
||||
s, err := api.Session(ctx)
|
||||
require.NoError(t, err)
|
||||
err = s.Send(&proto.Request{Type: &proto.Request_Config{Config: &proto.Config{}}})
|
||||
|
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
@@ -92,6 +93,7 @@ func NewClientService(options ClientServiceOptions) (
|
||||
return nil, xerrors.Errorf("register DRPC service: %w", err)
|
||||
}
|
||||
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if xerrors.Is(err, io.EOF) ||
|
||||
xerrors.Is(err, context.Canceled) ||
|
||||
|
Reference in New Issue
Block a user