package peer import ( "bytes" "context" "crypto/rand" "io" "sync" "time" "github.com/pion/logging" "github.com/pion/webrtc/v3" "go.uber.org/atomic" "golang.org/x/xerrors" "cdr.dev/slog" ) var ( // ErrDisconnected occurs when the connection has disconnected. // The connection will be attempting to reconnect at this point. ErrDisconnected = xerrors.New("connection is disconnected") // ErrFailed occurs when the connection has failed. // The connection will not retry after this point. ErrFailed = xerrors.New("connection has failed") // ErrClosed occurs when the connection was closed. It wraps io.EOF // to fulfill expected read errors from closed pipes. ErrClosed = xerrors.Errorf("connection was closed: %w", io.EOF) // The amount of random bytes sent in a ping. pingDataLength = 64 ) // Client creates a new client connection. func Client(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { return newWithClientOrServer(servers, true, opts) } // Server creates a new server connection. func Server(servers []webrtc.ICEServer, opts *ConnOptions) (*Conn, error) { return newWithClientOrServer(servers, false, opts) } // newWithClientOrServer constructs a new connection with the client option. // nolint:revive func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOptions) (*Conn, error) { if opts == nil { opts = &ConnOptions{} } opts.SettingEngine.DetachDataChannels() logger := logging.NewDefaultLoggerFactory() logger.DefaultLogLevel = logging.LogLevelDisabled opts.SettingEngine.LoggerFactory = logger api := webrtc.NewAPI(webrtc.WithSettingEngine(opts.SettingEngine)) rtc, err := api.NewPeerConnection(webrtc.Configuration{ ICEServers: servers, }) if err != nil { return nil, xerrors.Errorf("create peer connection: %w", err) } conn := &Conn{ pingChannelID: 1, pingEchoChannelID: 2, rtc: rtc, offerer: client, closed: make(chan struct{}), closedRTC: make(chan struct{}), closedICE: make(chan struct{}), dcOpenChannel: make(chan *webrtc.DataChannel, 8), dcDisconnectChannel: make(chan struct{}), dcFailedChannel: make(chan struct{}), localCandidateChannel: make(chan webrtc.ICECandidateInit), localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), negotiated: make(chan struct{}), remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1), settingEngine: opts.SettingEngine, } conn.loggerValue.Store(opts.Logger) if client { // If we're the client, we want to flip the echo and // ping channel IDs so pings don't accidentally hit each other. conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID } err = conn.init() if err != nil { return nil, xerrors.Errorf("init: %w", err) } return conn, nil } type ConnOptions struct { Logger slog.Logger // Enables customization on the underlying WebRTC connection. SettingEngine webrtc.SettingEngine } // Conn represents a WebRTC peer connection. // // This struct wraps webrtc.PeerConnection to add bidirectional pings, // concurrent-safe webrtc.DataChannel, and standardized errors for connection state. type Conn struct { rtc *webrtc.PeerConnection // Determines whether this connection will send the offer or the answer. offerer bool closed chan struct{} closedRTC chan struct{} closedRTCMutex sync.Mutex closedICE chan struct{} closedICEMutex sync.Mutex closeMutex sync.Mutex closeError error dcCreateMutex sync.Mutex dcOpenChannel chan *webrtc.DataChannel dcDisconnectChannel chan struct{} dcDisconnectListeners atomic.Uint32 dcFailedChannel chan struct{} dcFailedListeners atomic.Uint32 dcClosedWaitGroup sync.WaitGroup localCandidateChannel chan webrtc.ICECandidateInit localSessionDescriptionChannel chan webrtc.SessionDescription remoteSessionDescriptionChannel chan webrtc.SessionDescription negotiated chan struct{} loggerValue atomic.Value settingEngine webrtc.SettingEngine pingChannelID uint16 pingEchoChannelID uint16 pingEchoChan *Channel pingEchoOnce sync.Once pingEchoError error pingMutex sync.Mutex pingOnce sync.Once pingChan *Channel pingError error } func (c *Conn) logger() slog.Logger { log, valid := c.loggerValue.Load().(slog.Logger) if !valid { return slog.Logger{} } return log } func (c *Conn) init() error { c.rtc.OnNegotiationNeeded(c.negotiate) c.rtc.OnICEConnectionStateChange(func(iceConnectionState webrtc.ICEConnectionState) { c.closedICEMutex.Lock() defer c.closedICEMutex.Unlock() select { case <-c.closedICE: // Don't log more state changes if we've already closed. return default: c.logger().Debug(context.Background(), "ice connection state updated", slog.F("state", iceConnectionState)) if iceConnectionState == webrtc.ICEConnectionStateClosed { // pion/webrtc can update this state multiple times. // A connection can never become un-closed, so we // close the channel if it isn't already. close(c.closedICE) } } }) c.rtc.OnICEGatheringStateChange(func(iceGatherState webrtc.ICEGathererState) { c.closedICEMutex.Lock() defer c.closedICEMutex.Unlock() select { case <-c.closedICE: // Don't log more state changes if we've already closed. return default: c.logger().Debug(context.Background(), "ice gathering state updated", slog.F("state", iceGatherState)) if iceGatherState == webrtc.ICEGathererStateClosed { // pion/webrtc can update this state multiple times. // A connection can never become un-closed, so we // close the channel if it isn't already. close(c.closedICE) } } }) c.rtc.OnConnectionStateChange(func(peerConnectionState webrtc.PeerConnectionState) { go func() { c.closeMutex.Lock() defer c.closeMutex.Unlock() if c.isClosed() { return } c.logger().Debug(context.Background(), "rtc connection updated", slog.F("state", peerConnectionState)) }() switch peerConnectionState { case webrtc.PeerConnectionStateDisconnected: for i := 0; i < int(c.dcDisconnectListeners.Load()); i++ { select { case c.dcDisconnectChannel <- struct{}{}: default: } } case webrtc.PeerConnectionStateFailed: for i := 0; i < int(c.dcFailedListeners.Load()); i++ { select { case c.dcFailedChannel <- struct{}{}: default: } } case webrtc.PeerConnectionStateClosed: // pion/webrtc can update this state multiple times. // A connection can never become un-closed, so we // close the channel if it isn't already. c.closedRTCMutex.Lock() defer c.closedRTCMutex.Unlock() select { case <-c.closedRTC: default: close(c.closedRTC) } } }) // These functions need to check if the conn is closed, because they can be // called after being closed. c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) { c.logger().Debug(context.Background(), "signaling state updated", slog.F("state", signalState)) }) c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) { c.logger().Debug(context.Background(), "dtls transport state updated", slog.F("state", dtlsTransportState)) }) c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) { c.logger().Debug(context.Background(), "selected candidate pair changed", slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote)) }) c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) { if iceCandidate == nil { return } // Run this in a goroutine so we don't block pion/webrtc // from continuing. go func() { c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate)) select { case <-c.closed: case c.localCandidateChannel <- iceCandidate.ToJSON(): } }() }) c.rtc.OnDataChannel(func(dc *webrtc.DataChannel) { go func() { select { case <-c.closed: case c.dcOpenChannel <- dc: } }() }) _, err := c.pingChannel() if err != nil { return err } _, err = c.pingEchoChannel() if err != nil { return err } return nil } // negotiate is triggered when a connection is ready to be established. // See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/ func (c *Conn) negotiate() { c.logger().Debug(context.Background(), "negotiating") // ICE candidates cannot be added until SessionDescriptions have been // exchanged between peers. defer func() { select { case <-c.negotiated: default: close(c.negotiated) } }() if c.offerer { offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{}) if err != nil { _ = c.CloseWithError(xerrors.Errorf("create offer: %w", err)) return } // pion/webrtc will panic if Close is called while this // function is being executed. c.closeMutex.Lock() err = c.rtc.SetLocalDescription(offer) c.closeMutex.Unlock() if err != nil { _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) return } c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer)) select { case <-c.closed: return case c.localSessionDescriptionChannel <- offer: } c.logger().Debug(context.Background(), "sent offer") } var sessionDescription webrtc.SessionDescription c.logger().Debug(context.Background(), "awaiting remote description...") select { case <-c.closed: return case sessionDescription = <-c.remoteSessionDescriptionChannel: } c.logger().Debug(context.Background(), "setting remote description") err := c.rtc.SetRemoteDescription(sessionDescription) if err != nil { _ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err)) return } if !c.offerer { answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{}) if err != nil { _ = c.CloseWithError(xerrors.Errorf("create answer: %w", err)) return } // pion/webrtc will panic if Close is called while this // function is being executed. c.closeMutex.Lock() err = c.rtc.SetLocalDescription(answer) c.closeMutex.Unlock() if err != nil { _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) return } c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer)) select { case <-c.closed: return case c.localSessionDescriptionChannel <- answer: } c.logger().Debug(context.Background(), "sent answer") } } // AddRemoteCandidate adds a remote candidate to the RTC connection. func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) { if c.isClosed() { return } // This must occur in a goroutine to allow the SessionDescriptions // to be exchanged first. go func() { select { case <-c.closed: case <-c.negotiated: } if c.isClosed() { return } c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate)) err := c.rtc.AddICECandidate(i) if err != nil { if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed { return } _ = c.CloseWithError(xerrors.Errorf("accept candidate: %w", err)) } }() } // SetRemoteSessionDescription sets the remote description for the WebRTC connection. func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) { select { case <-c.closed: case c.remoteSessionDescriptionChannel <- sessionDescription: } } // LocalSessionDescription returns a channel that emits a session description // when one is required to be exchanged. func (c *Conn) LocalSessionDescription() <-chan webrtc.SessionDescription { return c.localSessionDescriptionChannel } // LocalCandidate returns a channel that emits when a local candidate // needs to be exchanged with a remote connection. func (c *Conn) LocalCandidate() <-chan webrtc.ICECandidateInit { return c.localCandidateChannel } func (c *Conn) pingChannel() (*Channel, error) { c.pingOnce.Do(func() { c.pingChan, c.pingError = c.dialChannel(context.Background(), "ping", &ChannelOptions{ ID: c.pingChannelID, Negotiated: true, OpenOnDisconnect: true, }) if c.pingError != nil { return } }) return c.pingChan, c.pingError } func (c *Conn) pingEchoChannel() (*Channel, error) { c.pingEchoOnce.Do(func() { c.pingEchoChan, c.pingEchoError = c.dialChannel(context.Background(), "echo", &ChannelOptions{ ID: c.pingEchoChannelID, Negotiated: true, OpenOnDisconnect: true, }) if c.pingEchoError != nil { return } go func() { for { data := make([]byte, pingDataLength) bytesRead, err := c.pingEchoChan.Read(data) if err != nil { _ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err)) return } _, err = c.pingEchoChan.Write(data[:bytesRead]) if err != nil { _ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err)) return } } }() }) return c.pingEchoChan, c.pingEchoError } // SetConfiguration applies options to the WebRTC connection. // Generally used for updating transport options, like ICE servers. func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error { return c.rtc.SetConfiguration(configuration) } // Accept blocks waiting for a channel to be opened. func (c *Conn) Accept(ctx context.Context) (*Channel, error) { var dataChannel *webrtc.DataChannel select { case <-ctx.Done(): return nil, ctx.Err() case <-c.closed: return nil, c.closeError case dataChannel = <-c.dcOpenChannel: } return newChannel(c, dataChannel, &ChannelOptions{}), nil } // CreateChannel creates a new DataChannel. func (c *Conn) CreateChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { if opts == nil { opts = &ChannelOptions{} } if opts.ID == c.pingChannelID || opts.ID == c.pingEchoChannelID { return nil, xerrors.Errorf("datachannel id %d and %d are reserved for ping", c.pingChannelID, c.pingEchoChannelID) } return c.dialChannel(ctx, label, opts) } func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) { // pion/webrtc is slower when opening multiple channels // in parallel than it is sequentially. c.dcCreateMutex.Lock() defer c.dcCreateMutex.Unlock() c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts)) var id *uint16 if opts.ID != 0 { id = &opts.ID } ordered := true if opts.Unordered { ordered = false } if opts.OpenOnDisconnect && !opts.Negotiated { return nil, xerrors.New("OpenOnDisconnect is only allowed for Negotiated channels") } if c.isClosed() { return nil, xerrors.Errorf("closed: %w", c.closeError) } dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{ ID: id, Negotiated: &opts.Negotiated, Ordered: &ordered, Protocol: &opts.Protocol, }) if err != nil { return nil, xerrors.Errorf("create data channel: %w", err) } return newChannel(c, dataChannel, opts), nil } // Ping returns the duration it took to round-trip data. // Multiple pings cannot occur at the same time, so this function will block. func (c *Conn) Ping() (time.Duration, error) { // Pings are not async, so we need a mutex. c.pingMutex.Lock() defer c.pingMutex.Unlock() ping, err := c.pingChannel() if err != nil { return 0, xerrors.Errorf("get ping channel: %w", err) } pingDataSent := make([]byte, pingDataLength) _, err = rand.Read(pingDataSent) if err != nil { return 0, xerrors.Errorf("read random ping data: %w", err) } start := time.Now() _, err = ping.Write(pingDataSent) if err != nil { return 0, xerrors.Errorf("send ping: %w", err) } c.logger().Debug(context.Background(), "wrote ping", slog.F("connection_state", c.rtc.ConnectionState())) pingDataReceived := make([]byte, pingDataLength) _, err = ping.Read(pingDataReceived) if err != nil { return 0, xerrors.Errorf("read ping: %w", err) } end := time.Now() if !bytes.Equal(pingDataSent, pingDataReceived) { return 0, xerrors.Errorf("ping data inconsistency sent != received") } return end.Sub(start), nil } func (c *Conn) Closed() <-chan struct{} { return c.closed } // Close closes the connection and frees all associated resources. func (c *Conn) Close() error { return c.CloseWithError(nil) } func (c *Conn) isClosed() bool { select { case <-c.closed: return true default: return false } } // CloseWithError closes the connection; subsequent reads/writes will return the error err. func (c *Conn) CloseWithError(err error) error { c.closeMutex.Lock() defer c.closeMutex.Unlock() if c.isClosed() { return c.closeError } logger := c.logger() logger.Debug(context.Background(), "closing conn with error", slog.Error(err)) if err == nil { c.closeError = ErrClosed } else { c.closeError = err } if ch, _ := c.pingChannel(); ch != nil { _ = ch.closeWithError(c.closeError) } // If the WebRTC connection has already been closed (due to failure or disconnect), // this call will return an error that isn't typed. We don't check the error because // closing an already closed connection isn't an issue for us. _ = c.rtc.Close() // Waiting for pion/webrtc to report closed state on both of these // ensures no goroutine leaks. if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew { logger.Debug(context.Background(), "waiting for rtc connection close...") <-c.closedRTC } if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew { logger.Debug(context.Background(), "waiting for ice connection close...") <-c.closedICE } // Waits for all DataChannels to exit before officially labeling as closed. // All logging, goroutines, and async functionality is cleaned up after this. c.dcClosedWaitGroup.Wait() // Disable logging! c.loggerValue.Store(slog.Logger{}) logger.Sync() logger.Debug(context.Background(), "closed") close(c.closed) return err }