package vpn

import (
	"context"
	"fmt"
	"io"
	"strings"
	"sync"

	"golang.org/x/xerrors"
	"google.golang.org/protobuf/proto"

	"cdr.dev/slog"
)

// SpeakerRole identifies which end of the CoderVPN protocol a speaker plays
// (see SpeakerRoleManager and SpeakerRoleTunnel below).
type SpeakerRole string

// rpcMessage is a protobuf message that carries an embedded RPC header, making it usable in
// unary (request/response) exchanges.
type rpcMessage interface {
	proto.Message
	GetRpc() *RPC
	// EnsureRPC isn't autogenerated, but we'll manually add it for RPC types so that the speaker
	// can allocate the RPC.
	EnsureRPC() *RPC
}

// EnsureRPC returns the message's RPC header, allocating it first if it is nil.
func (t *TunnelMessage) EnsureRPC() *RPC {
	if t.Rpc == nil {
		t.Rpc = &RPC{}
	}
	return t.Rpc
}

// EnsureRPC returns the message's RPC header, allocating it first if it is nil.
func (m *ManagerMessage) EnsureRPC() *RPC {
	if m.Rpc == nil {
		m.Rpc = &RPC{}
	}
	return m.Rpc
}

// receivableRPCMessage is an rpcMessage that we can receive, and unmarshal, using generics, from a
// byte stream. proto.Unmarshal requires us to have already allocated the memory for the message
// type we are unmarshalling. All our message types are pointers like *TunnelMessage, so to
// allocate, the compiler needs to know:
//
// a) that the type is a pointer type
// b) what type it is pointing to
//
// So, this generic interface requires that the message is a pointer to the type RR. Then, we pass
// both the receivableRPCMessage and RR as type constraints, so that we'll have access to the
// underlying type when it comes time to allocate it. It's a bit messy, but the alternative is
// reflection, which has its own challenges in understandability.
type receivableRPCMessage[RR any] interface {
	rpcMessage
	*RR
}

const (
	SpeakerRoleManager SpeakerRole = "manager"
	SpeakerRoleTunnel  SpeakerRole = "tunnel"
)

// speaker is an implementation of the CoderVPN protocol. It handles unary RPCs and their responses,
// as well as the low-level serialization & deserialization to the ReadWriteCloser (rwc).
//
// Data flows as follows:
//
//	           ┌────────┐  sendCh   (rpc requests via unaryRPC(), replies via sendReply())
//	 ◄─────────┤        │◄─────────────────────────────────────────────────◄
//	   rwc     │ serdes │
//	 ─────────►│        ├──► recvFromSerdes() ──► rpc responses ──► waiting unaryRPC() call
//	           └────────┘  recvCh             └─► requests chan ──► request handling
//	                                                               (outside speaker)
//
// speaker is implemented as a generic type that accepts the type of message we send (S), the type we receive (R), and
// the underlying type that R points to (RR). The speaker is intended to be wrapped by another, non-generic type for
// the role (manager or tunnel). E.g. Tunnel from this package.
//
// The serdes handles SERialization and DESerialization of the low level message types. The wrapping type may send
// non-RPC messages (that is messages that don't expect an explicit reply) by sending on the sendCh.
//
// Unary RPCs are handled by the unaryRPC() function, which handles sending the message and waiting for the response.
//
// recvFromSerdes() reads all incoming messages from the serdes. If they are RPC responses, it dispatches them to the
// waiting unaryRPC() function call, if any. If they are RPC requests or non-RPC messages, it wraps them in a request
// struct and sends them over the requests chan. The manager/tunnel role type must read from this chan and handle
// the requests. If they are RPC types, it should call sendReply() on the request with the reply message.
type speaker[S rpcMessage, R receivableRPCMessage[RR], RR any] struct {
	// serdes performs the low-level (de)serialization to/from the underlying conn.
	serdes *serdes[S, R, RR]
	// requests delivers incoming RPC requests and non-RPC messages to the wrapping
	// role type (manager/tunnel), which must read from it.
	requests  chan *request[S, R]
	logger    slog.Logger
	nextMsgID uint64
	ctx       context.Context
	cancel    context.CancelFunc
	// sendCh carries outgoing messages to the serdes.
	sendCh chan<- S
	// recvCh carries incoming messages from the serdes.
	recvCh <-chan R
	// recvLoopDone is closed when recvFromSerdes exits.
	recvLoopDone chan struct{}

	// mu guards nextMsgID and responseChans.
	mu sync.Mutex
	// responseChans maps an in-flight RPC's message ID to the channel on which its
	// response will be delivered.
	responseChans map[uint64]chan R
}

// newSpeaker creates a new protocol speaker. It performs the CoderVPN handshake over conn
// before returning; on handshake failure the derived context is canceled and an error is
// returned.
func newSpeaker[S rpcMessage, R receivableRPCMessage[RR], RR any](
	ctx context.Context, logger slog.Logger, conn io.ReadWriteCloser, me, them SpeakerRole,
) (
	*speaker[S, R, RR], error,
) {
	ctx, cancel := context.WithCancel(ctx)
	if err := handshake(ctx, conn, logger, me, them); err != nil {
		cancel()
		return nil, xerrors.Errorf("handshake failed: %w", err)
	}
	sendCh := make(chan S)
	recvCh := make(chan R)
	s := &speaker[S, R, RR]{
		serdes:        newSerdes(ctx, logger, conn, sendCh, recvCh),
		logger:        logger,
		requests:      make(chan *request[S, R]),
		responseChans: make(map[uint64]chan R),
		nextMsgID:     1,
		ctx:           ctx,
		cancel:        cancel,
		sendCh:        sendCh,
		recvCh:        recvCh,
		recvLoopDone:  make(chan struct{}),
	}
	return s, nil
}

// start starts the serialization/deserialization. It's important this happens
// after any assignments of the speaker to its owning Tunnel or Manager, since
// the mutex is copied and that is not threadsafe.
// nolint: revive func (s *speaker[_, _, _]) start() { s.serdes.start() go s.recvFromSerdes() } func (s *speaker[S, R, _]) recvFromSerdes() { defer close(s.recvLoopDone) defer close(s.requests) for { select { case <-s.ctx.Done(): s.logger.Debug(s.ctx, "recvFromSerdes context done while waiting for proto", slog.Error(s.ctx.Err())) return case msg, ok := <-s.recvCh: if !ok { s.logger.Debug(s.ctx, "recvCh is closed") return } rpc := msg.GetRpc() if rpc != nil && rpc.ResponseTo != 0 { // this is a unary response s.tryToDeliverResponse(msg) continue } req := &request[S, R]{ ctx: s.ctx, msg: msg, replyCh: s.sendCh, } select { case <-s.ctx.Done(): s.logger.Debug(s.ctx, "recvFromSerdes context done while waiting for request handler", slog.Error(s.ctx.Err())) return case s.requests <- req: } } } } // Close closes the speaker // nolint: revive func (s *speaker[_, _, _]) Close() error { s.cancel() err := s.serdes.Close() return err } // unaryRPC sends a request/response style RPC over the protocol, waits for the response, then // returns the response func (s *speaker[S, R, _]) unaryRPC(ctx context.Context, req S) (resp R, err error) { rpc := req.EnsureRPC() msgID, respCh := s.newRPC() rpc.MsgId = msgID logger := s.logger.With(slog.F("msg_id", msgID)) select { case <-ctx.Done(): return resp, ctx.Err() case <-s.ctx.Done(): return resp, xerrors.Errorf("vpn protocol closed: %w", s.ctx.Err()) case <-s.recvLoopDone: logger.Debug(s.ctx, "recvLoopDone while sending request") return resp, io.ErrUnexpectedEOF case s.sendCh <- req: logger.Debug(s.ctx, "sent rpc request", slog.F("req", req)) } select { case <-ctx.Done(): s.rmResponseChan(msgID) return resp, ctx.Err() case <-s.ctx.Done(): s.rmResponseChan(msgID) return resp, xerrors.Errorf("vpn protocol closed: %w", s.ctx.Err()) case <-s.recvLoopDone: logger.Debug(s.ctx, "recvLoopDone while waiting for response") return resp, io.ErrUnexpectedEOF case resp = <-respCh: logger.Debug(s.ctx, "got response", slog.F("resp", resp)) return resp, 
nil } } func (s *speaker[_, R, _]) newRPC() (uint64, chan R) { s.mu.Lock() defer s.mu.Unlock() msgID := s.nextMsgID s.nextMsgID++ c := make(chan R) s.responseChans[msgID] = c return msgID, c } func (s *speaker[_, _, _]) rmResponseChan(msgID uint64) { s.mu.Lock() defer s.mu.Unlock() delete(s.responseChans, msgID) } func (s *speaker[_, R, _]) tryToDeliverResponse(resp R) { msgID := resp.GetRpc().GetResponseTo() s.mu.Lock() defer s.mu.Unlock() c, ok := s.responseChans[msgID] if ok { c <- resp // Remove the channel since we delivered a response. This ensures that each response channel // gets _at most_ one response. Since the channels are buffered with size 1, send will // never block. delete(s.responseChans, msgID) } } // handshake performs the initial CoderVPN protocol handshake over the given conn func handshake( ctx context.Context, conn io.ReadWriteCloser, logger slog.Logger, me, them SpeakerRole, ) error { // read and write simultaneously to avoid deadlocking if the conn is not buffered errCh := make(chan error, 2) go func() { ours := headerString(me, CurrentSupportedVersions) _, err := conn.Write([]byte(ours)) logger.Debug(ctx, "wrote out header") if err != nil { err = xerrors.Errorf("write header: %w", err) } errCh <- err }() headerCh := make(chan string, 1) go func() { // we can't use bufio.Scanner here because we need to ensure we don't read beyond the // first newline. So, we'll read one byte at a time. It's inefficient, but the initial // header is only a few characters, so we'll keep this code simple. buf := make([]byte, 256) have := 0 for { _, err := conn.Read(buf[have : have+1]) if err != nil { errCh <- xerrors.Errorf("read header: %w", err) return } if buf[have] == '\n' { logger.Debug(ctx, "got newline header delimiter") // use have (not have+1) since we don't want the delimiter for verification. 
headerCh <- string(buf[:have]) return } have++ if have >= len(buf) { errCh <- xerrors.Errorf("header malformed or too large: %s", string(buf)) return } } }() writeOK := false theirHeader := "" readOK := false for !(readOK && writeOK) { select { case <-ctx.Done(): _ = conn.Close() // ensure our read/write goroutines get a chance to clean up return ctx.Err() case err := <-errCh: if err == nil { // write goroutine sends nil when completing successfully. logger.Debug(ctx, "write ok") writeOK = true continue } _ = conn.Close() return err case theirHeader = <-headerCh: logger.Debug(ctx, "read ok") readOK = true } } logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader)) gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions) if err != nil { return xerrors.Errorf("validate header (%s): %w", theirHeader, err) } logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion)) // TODO: actually use the common version to perform different behavior once // we have multiple versions return nil } const headerPreamble = "codervpn" func headerString(role SpeakerRole, versions RPCVersionList) string { return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String()) } func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) { parts := strings.Split(header, " ") if len(parts) != 3 { return RPCVersion{}, xerrors.New("wrong number of parts") } if parts[0] != headerPreamble { return RPCVersion{}, xerrors.New("invalid preamble") } if parts[1] != string(expectedRole) { return RPCVersion{}, xerrors.New("unexpected role") } otherVersions, err := ParseRPCVersionList(parts[2]) if err != nil { return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err) } compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions) if !ok { return RPCVersion{}, xerrors.Errorf("current supported versions %q is not compatible with peer versions 
%q", supportedVersions.String(), otherVersions.String()) } return compatibleVersion, nil } type request[S rpcMessage, R rpcMessage] struct { ctx context.Context msg R replyCh chan<- S } func (r *request[S, _]) sendReply(reply S) error { rrpc := reply.EnsureRPC() mrpc := r.msg.GetRpc() if mrpc == nil { return xerrors.Errorf("message didn't want a reply") } rrpc.ResponseTo = mrpc.MsgId select { case <-r.ctx.Done(): return r.ctx.Err() case r.replyCh <- reply: } return nil }