Files
coder/vpn/serdes.go

140 lines
3.4 KiB
Go

package vpn
import (
"context"
"encoding/binary"
"io"
"sync"
"google.golang.org/protobuf/proto"
"cdr.dev/slog"
)
// MaxLength is the largest possible CoderVPN Protocol message size, in bytes.
// This is set so that a misbehaving peer can't cause us to allocate a huge
// amount of memory: recvLoop compares the length prefix of every incoming
// frame against this limit before allocating the buffer for the message body.
const MaxLength = 0x1000000 // 16MiB
// serdes SERializes and DESerializes protobuf messages to and from the conn.
//
// S is the type of message written to the conn, R is the type of message
// delivered on recvCh, and RR is the type that recvLoop allocates with new()
// and converts to R for each incoming message.
type serdes[S rpcMessage, R receivableRPCMessage[RR], RR any] struct {
// ctx is passed to every log call and, when done, stops sendLoop.
ctx context.Context
logger slog.Logger
// conn is the byte stream that recvLoop reads from and sendLoop writes to.
// It is closed via closeIdempotent.
conn io.ReadWriteCloser
// sendCh supplies the messages that sendLoop serializes onto the conn.
sendCh <-chan S
// recvCh receives the messages that recvLoop deserializes from the conn.
// recvLoop closes it when it exits.
recvCh chan<- R
// closeOnce ensures conn.Close is called at most once.
closeOnce sync.Once
// wg tracks the goroutines launched by start so that Close can wait for them.
wg sync.WaitGroup
}
// recvLoop reads framed protobuf messages from the conn and passes them to
// recvCh. Each frame is a 4-byte big-endian length followed by that many bytes
// of protobuf-encoded message.
//
// The loop ends when reading the length or body fails, when the declared
// length exceeds MaxLength, when the body cannot be unmarshaled, or when the
// context is canceled while a decoded message is waiting to be delivered. On
// exit recvCh is closed and then the conn is closed.
func (s *serdes[_, R, RR]) recvLoop() {
	s.logger.Debug(s.ctx, "starting recvLoop")
	// Deferred calls run in reverse order: recvCh is closed first, then the conn.
	defer s.closeIdempotent()
	defer close(s.recvCh)
	for {
		var length uint32
		if err := binary.Read(s.conn, binary.BigEndian, &length); err != nil {
			s.logger.Debug(s.ctx, "failed to read length", slog.Error(err))
			return
		}
		// Check the peer-supplied length before allocating so that a bogus
		// prefix cannot make us reserve an arbitrarily large buffer.
		if length > MaxLength {
			s.logger.Critical(s.ctx, "message length exceeds max",
				slog.F("length", length))
			return
		}
		s.logger.Debug(s.ctx, "about to read message", slog.F("length", length))
		mb := make([]byte, length)
		if n, err := io.ReadFull(s.conn, mb); err != nil {
			s.logger.Debug(s.ctx, "failed to read message",
				slog.Error(err),
				slog.F("num_bytes_read", n))
			return
		}
		msg := R(new(RR))
		if err := proto.Unmarshal(mb, msg); err != nil {
			s.logger.Critical(s.ctx, "failed to unmarshal message", slog.Error(err))
			return
		}
		select {
		case s.recvCh <- msg:
			s.logger.Debug(s.ctx, "passed received message to speaker")
		case <-s.ctx.Done():
			s.logger.Debug(s.ctx, "recvLoop canceled", slog.Error(s.ctx.Err()))
			// Once the context is done this branch is always ready, so
			// continuing would keep reading frames and may drop each decoded
			// message. Stop instead.
			return
		}
	}
}
// sendLoop takes messages from sendCh, marshals them, and writes them to the
// conn as a 4-byte big-endian length followed by the protobuf-encoded bytes.
//
// The loop ends when the context is canceled, when sendCh is closed, when a
// message cannot be marshaled, when the encoded message is larger than
// MaxLength, or when a write to the conn fails. On exit the conn is closed.
func (s *serdes[S, _, _]) sendLoop() {
	s.logger.Debug(s.ctx, "starting sendLoop")
	defer s.closeIdempotent()
	for {
		select {
		case <-s.ctx.Done():
			s.logger.Debug(s.ctx, "sendLoop canceled", slog.Error(s.ctx.Err()))
			return
		case msg, ok := <-s.sendCh:
			if !ok {
				s.logger.Debug(s.ctx, "sendCh closed")
				return
			}
			mb, err := proto.Marshal(msg)
			if err != nil {
				s.logger.Critical(s.ctx, "failed to marshal message", slog.Error(err))
				return
			}
			// A peer running recvLoop rejects any frame longer than MaxLength,
			// so fail here rather than put it on the wire. The check also
			// guarantees that the uint32 conversion below cannot truncate.
			if len(mb) > MaxLength {
				s.logger.Critical(s.ctx, "message length exceeds max",
					slog.F("length", len(mb)))
				return
			}
			if err := binary.Write(s.conn, binary.BigEndian, uint32(len(mb))); err != nil {
				s.logger.Debug(s.ctx, "failed to write length", slog.Error(err))
				return
			}
			if _, err := s.conn.Write(mb); err != nil {
				s.logger.Debug(s.ctx, "failed to write message", slog.Error(err))
				return
			}
		}
	}
}
// closeIdempotent closes the conn the first time it is called and logs the
// result. Later calls do nothing, so both loops and Close can call it freely.
func (s *serdes[_, _, _]) closeIdempotent() {
	s.closeOnce.Do(func() {
		closeErr := s.conn.Close()
		if closeErr == nil {
			s.logger.Info(s.ctx, "closed connection")
			return
		}
		s.logger.Error(s.ctx, "failed to close connection", slog.Error(closeErr))
	})
}
// Close shuts the serdes down: it closes the conn (if that has not already
// happened) and then blocks until the goroutines launched by start have
// returned. It always returns nil; a failure to close the conn is logged by
// closeIdempotent rather than returned.
// nolint: revive
func (s *serdes[_, _, _]) Close() error {
	// Close the conn first so that loops blocked on conn I/O can unblock.
	s.closeIdempotent()
	s.wg.Wait()
	return nil
}
// start launches recvLoop and sendLoop, each in its own goroutine, and
// registers both with the wait group so that Close can wait for them to end.
// nolint: revive
func (s *serdes[_, _, _]) start() {
	loops := []func(){s.recvLoop, s.sendLoop}
	s.wg.Add(len(loops))
	for _, loop := range loops {
		// Pass the loop in as an argument so each goroutine gets its own copy.
		go func(run func()) {
			defer s.wg.Done()
			run()
		}(loop)
	}
}
// newSerdes builds a serdes that reads from and writes to conn, taking the
// messages to send from sendCh and delivering received messages on recvCh.
// The logger is scoped with the name "serdes". No goroutines are started
// here; call start to begin processing.
func newSerdes[S rpcMessage, R receivableRPCMessage[RR], RR any](
	ctx context.Context, logger slog.Logger, conn io.ReadWriteCloser,
	sendCh <-chan S, recvCh chan<- R,
) *serdes[S, R, RR] {
	sd := new(serdes[S, R, RR])
	sd.ctx = ctx
	sd.logger = logger.Named("serdes")
	sd.conn = conn
	sd.sendCh = sendCh
	sd.recvCh = recvCh
	return sd
}