mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat: Add peerbroker proxy for agent connections (#349)
* feat: Add peerbroker proxy for agent connections Agents will connect using this proxy. Eventually we'll intercept some of these messages for validation, but that's not necessary right now. * Add ASCII chart
This commit is contained in:
260
peerbroker/proxy.go
Normal file
260
peerbroker/proxy.go
Normal file
@ -0,0 +1,260 @@
|
||||
package peerbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"golang.org/x/xerrors"
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
// Each NegotiateConnection() function call spawns a new stream.
|
||||
streamIDLength = len(uuid.NewString())
|
||||
// We shouldn't PubSub anything larger than this!
|
||||
maxPayloadSizeBytes = 8192
|
||||
)
|
||||
|
||||
// ProxyOptions provides values to configure a proxy.
|
||||
type ProxyOptions struct {
|
||||
ChannelID string
|
||||
Logger slog.Logger
|
||||
Pubsub database.Pubsub
|
||||
}
|
||||
|
||||
// ProxyDial writes client negotiation streams over PubSub.
|
||||
//
|
||||
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
|
||||
// messages are small in size (<=8KB), and we don't require delivery
|
||||
// guarantees because connections can always be renegotiated.
|
||||
// ┌────────────────────┐ ┌─────────────────────────────┐
|
||||
// │ coderd │ │ coderd │
|
||||
// ┌─────────────────────┐ │/<agent-id>/connect │ │ /<agent-id>/listen │
|
||||
// │ client │ │ │ │ │ ┌─────┐
|
||||
// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
|
||||
// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘
|
||||
// └─────────────────────┘ │<agent-id> channel: │ │from payloads to create new │
|
||||
// │ │ │NegotiateConnection() streams│
|
||||
// │<stream-id><payload>│ │or write to existing ones. │
|
||||
// └────────────────────┘ └─────────────────────────────┘
|
||||
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
|
||||
proxyDial := &proxyDial{
|
||||
channelID: options.ChannelID,
|
||||
logger: options.Logger,
|
||||
pubsub: options.Pubsub,
|
||||
connection: client,
|
||||
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
|
||||
}
|
||||
return proxyDial, proxyDial.listen()
|
||||
}
|
||||
|
||||
// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
|
||||
// as new NegotiateConnection() streams.
|
||||
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
|
||||
channelID: options.ChannelID,
|
||||
pubsub: options.Pubsub,
|
||||
logger: options.Logger,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("register peer broker: %w", err)
|
||||
}
|
||||
server := drpcserver.New(mux)
|
||||
err = server.Serve(ctx, connListener)
|
||||
if err != nil {
|
||||
if errors.Is(err, yamux.ErrSessionShutdown) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("serve: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyListen struct {
|
||||
channelID string
|
||||
pubsub database.Pubsub
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
|
||||
streamID := uuid.NewString()
|
||||
var err error
|
||||
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
|
||||
err := p.onServerToClientMessage(streamID, stream, message)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("subscribe: %w", err)
|
||||
}
|
||||
defer closeSubscribe()
|
||||
for {
|
||||
clientToServerMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return xerrors.Errorf("recv: %w", err)
|
||||
}
|
||||
data, err := protobuf.Marshal(clientToServerMessage)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(data) > maxPayloadSizeBytes {
|
||||
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
|
||||
}
|
||||
data = append([]byte(streamID), data...)
|
||||
err = p.pubsub.Publish(proxyOutID(p.channelID), data)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
|
||||
if len(message) < streamIDLength {
|
||||
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
|
||||
}
|
||||
serverStreamID := string(message[0:streamIDLength])
|
||||
if serverStreamID != streamID {
|
||||
// It's not trying to communicate with this stream!
|
||||
return nil
|
||||
}
|
||||
var msg proto.NegotiateConnection_ServerToClient
|
||||
err := protobuf.Unmarshal(message[streamIDLength:], &msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unmarshal message: %w", err)
|
||||
}
|
||||
err = stream.Send(&msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("send message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyDial struct {
|
||||
channelID string
|
||||
pubsub database.Pubsub
|
||||
logger slog.Logger
|
||||
|
||||
connection proto.DRPCPeerBrokerClient
|
||||
closeSubscribe func()
|
||||
streamMutex sync.Mutex
|
||||
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
|
||||
}
|
||||
|
||||
func (p *proxyDial) listen() error {
|
||||
var err error
|
||||
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
|
||||
err := p.onClientToServerMessage(ctx, message)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
|
||||
if len(message) < streamIDLength {
|
||||
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
|
||||
}
|
||||
var err error
|
||||
streamID := string(message[0:streamIDLength])
|
||||
p.streamMutex.Lock()
|
||||
stream, ok := p.streams[streamID]
|
||||
if !ok {
|
||||
stream, err = p.connection.NegotiateConnection(ctx)
|
||||
if err != nil {
|
||||
p.streamMutex.Unlock()
|
||||
return xerrors.Errorf("negotiate connection: %w", err)
|
||||
}
|
||||
p.streams[streamID] = stream
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
|
||||
err = p.onServerToClientMessage(streamID, stream)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
<-stream.Context().Done()
|
||||
p.streamMutex.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMutex.Unlock()
|
||||
}()
|
||||
}
|
||||
p.streamMutex.Unlock()
|
||||
|
||||
var msg proto.NegotiateConnection_ClientToServer
|
||||
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unmarshal message: %w", err)
|
||||
}
|
||||
err = stream.Send(&msg)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write message: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
|
||||
for {
|
||||
serverToClientMessage, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
break
|
||||
}
|
||||
return xerrors.Errorf("recv: %w", err)
|
||||
}
|
||||
data, err := protobuf.Marshal(serverToClientMessage)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal: %w", err)
|
||||
}
|
||||
if len(data) > maxPayloadSizeBytes {
|
||||
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
|
||||
}
|
||||
data = append([]byte(streamID), data...)
|
||||
err = p.pubsub.Publish(proxyInID(p.channelID), data)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("publish: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *proxyDial) Close() error {
|
||||
p.streamMutex.Lock()
|
||||
defer p.streamMutex.Unlock()
|
||||
p.closeSubscribe()
|
||||
return nil
|
||||
}
|
||||
|
||||
// proxyOutID derives the name of the "out" PubSub topic for a channel by
// appending "-out" to channelID.
func proxyOutID(channelID string) string {
	const format = "%s-out"
	return fmt.Sprintf(format, channelID)
}
|
||||
|
||||
// proxyInID derives the name of the "in" PubSub topic for a channel by
// appending "-in" to channelID.
func proxyInID(channelID string) string {
	const format = "%s-in"
	return fmt.Sprintf(format, channelID)
}
|
81
peerbroker/proxy_test.go
Normal file
81
peerbroker/proxy_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package peerbroker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/database"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
)
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
channelID := "hello"
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
dialerClient, dialerServer := provisionersdk.TransportPipe()
|
||||
defer dialerClient.Close()
|
||||
defer dialerServer.Close()
|
||||
listenerClient, listenerServer := provisionersdk.TransportPipe()
|
||||
defer listenerClient.Close()
|
||||
defer listenerServer.Close()
|
||||
|
||||
listener, err := peerbroker.Listen(listenerServer, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
|
||||
ChannelID: channelID,
|
||||
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
|
||||
Pubsub: pubsub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = proxyCloser.Close()
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
|
||||
ChannelID: channelID,
|
||||
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
|
||||
Pubsub: pubsub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
|
||||
stream, err := api.NegotiateConnection(ctx)
|
||||
require.NoError(t, err)
|
||||
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
}}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
serverConn, err := listener.Accept()
|
||||
require.NoError(t, err)
|
||||
defer serverConn.Close()
|
||||
_, err = serverConn.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = clientConn.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = dialerServer.Close()
|
||||
wg.Wait()
|
||||
}
|
Reference in New Issue
Block a user