package database

import (
	"context"
	"database/sql"
	"errors"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/lib/pq"
	"golang.org/x/xerrors"
)

// Listener represents a pubsub handler.
type Listener func(ctx context.Context, message []byte)

// ListenerWithErr represents a pubsub handler that can also receive error
// indications.
type ListenerWithErr func(ctx context.Context, message []byte, err error)

// ErrDroppedMessages is sent to ListenerWithErr if messages are dropped or
// might have been dropped.
var ErrDroppedMessages = xerrors.New("dropped messages")

// Pubsub is a generic interface for broadcasting and receiving messages.
// Implementors should assume high-availability with the backing implementation.
type Pubsub interface {
	Subscribe(event string, listener Listener) (cancel func(), err error)
	SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error)
	Publish(event string, message []byte) error
	Close() error
}

// msgOrErr either contains a message or an error.
type msgOrErr struct {
	msg []byte
	err error
}

// msgQueue implements a fixed length queue with the ability to replace elements
// after they are queued (but before they are dequeued).
//
// The purpose of this data structure is to build something that works a bit
// like a golang channel, but if the queue is full, then we can replace the
// last element with an error so that the subscriber can get notified that some
// messages were dropped, all without blocking.
type msgQueue struct {
	ctx    context.Context
	cond   *sync.Cond
	q      [PubsubBufferSize]msgOrErr
	front  int
	size   int
	closed bool
	l      Listener
	le     ListenerWithErr
}

func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue {
	if l == nil && le == nil {
		panic("l or le must be non-nil")
	}
	q := &msgQueue{
		ctx:  ctx,
		cond: sync.NewCond(&sync.Mutex{}),
		l:    l,
		le:   le,
	}
	go q.run()
	return q
}

func (q *msgQueue) run() {
	for {
		// wait until there is something on the queue or we are closed
		q.cond.L.Lock()
		for q.size == 0 && !q.closed {
			q.cond.Wait()
		}
		if q.closed {
			q.cond.L.Unlock()
			return
		}
		item := q.q[q.front]
		q.front = (q.front + 1) % PubsubBufferSize
		q.size--
		q.cond.L.Unlock()

		// process item without holding lock
		if item.err == nil {
			// real message
			if q.l != nil {
				q.l(q.ctx, item.msg)
				continue
			}
			if q.le != nil {
				q.le(q.ctx, item.msg, nil)
				continue
			}
			// unhittable: newMsgQueue requires at least one listener
			continue
		}
		// if the listener wants errors, send it.
		if q.le != nil {
			q.le(q.ctx, nil, item.err)
		}
	}
}

func (q *msgQueue) enqueue(msg []byte) {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()

	if q.size == PubsubBufferSize {
		// queue is full, so we're going to drop the msg we got called with.
		// We also need to record that messages are being dropped, which we
		// do at the last message in the queue. This potentially makes us
		// lose 2 messages instead of one, but it's more important at this
		// point to warn the subscriber that they're losing messages so they
		// can do something about it.
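		// The "back" slot of the (currently full) queue is one position
		// before front; adding PubsubBufferSize before taking the modulus
		// keeps the index non-negative.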
		back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
		q.q[back].msg = nil
		q.q[back].err = ErrDroppedMessages
		return
	}
	// queue is not full, insert the message
	next := (q.front + q.size) % PubsubBufferSize
	q.q[next].msg = msg
	q.q[next].err = nil
	q.size++
	q.cond.Broadcast()
}

func (q *msgQueue) close() {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()
	defer q.cond.Broadcast()
	q.closed = true
}

// dropped records an error in the queue that messages might have been dropped.
func (q *msgQueue) dropped() {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()

	if q.size == PubsubBufferSize {
		// queue is full, but we need to record that messages are being dropped,
		// which we do at the last message in the queue. This potentially drops
		// another message, but it's more important for the subscriber to know.
		back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
		q.q[back].msg = nil
		q.q[back].err = ErrDroppedMessages
		return
	}
	// queue is not full, insert the error
	next := (q.front + q.size) % PubsubBufferSize
	q.q[next].msg = nil
	q.q[next].err = ErrDroppedMessages
	q.size++
	q.cond.Broadcast()
}

// pgPubsub is a Pubsub implementation using PostgreSQL.
type pgPubsub struct {
	ctx        context.Context
	cancel     context.CancelFunc
	listenDone chan struct{}
	pgListener *pq.Listener
	db         *sql.DB
	mut        sync.Mutex
	queues     map[string]map[uuid.UUID]*msgQueue
}

// PubsubBufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const PubsubBufferSize = 2048

// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
	return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
}

// SubscribeWithErr calls the listener when an event matching the name is
// received, and also delivers error indications such as ErrDroppedMessages.
func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
	return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
}

func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
	p.mut.Lock()
	defer p.mut.Unlock()
	defer func() {
		if err != nil {
			// if we hit an error, we need to close the queue so we don't
			// leak its goroutine.
			newQ.close()
		}
	}()

	err = p.pgListener.Listen(event)
	if errors.Is(err, pq.ErrChannelAlreadyOpen) {
		// It's ok if it's already open!
		err = nil
	}
	if err != nil {
		return nil, xerrors.Errorf("listen: %w", err)
	}

	var eventQs map[uuid.UUID]*msgQueue
	var ok bool
	if eventQs, ok = p.queues[event]; !ok {
		eventQs = make(map[uuid.UUID]*msgQueue)
		p.queues[event] = eventQs
	}
	id := uuid.New()
	eventQs[id] = newQ
	return func() {
		p.mut.Lock()
		defer p.mut.Unlock()
		listeners := p.queues[event]
		q := listeners[id]
		q.close()
		delete(listeners, id)
		if len(listeners) == 0 {
			_ = p.pgListener.Unlisten(event)
		}
	}, nil
}

func (p *pgPubsub) Publish(event string, message []byte) error {
	// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
	// support the first parameter being a prepared statement.
	//nolint:gosec
	_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
	if err != nil {
		return xerrors.Errorf("exec pg_notify: %w", err)
	}
	return nil
}

// Close closes the pubsub instance.
func (p *pgPubsub) Close() error {
	p.cancel()
	err := p.pgListener.Close()
	<-p.listenDone
	return err
}
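
// exampleSubscribeWithErr is a minimal usage sketch: it assumes a hypothetical
// "example-event" channel and shows how a ListenerWithErr subscriber might
// treat ErrDroppedMessages, which the queue reports after a buffer overflow or
// a reconnect of the underlying connection.
func exampleSubscribeWithErr(ps Pubsub) (cancel func(), err error) {
	return ps.SubscribeWithErr("example-event", func(_ context.Context, message []byte, err error) {
		if errors.Is(err, ErrDroppedMessages) {
			// Some notifications may have been lost; refetch state rather
			// than trusting the notification stream.
			return
		}
		// Handle the broadcast payload.
		_ = message
	})
}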

// listen begins receiving messages on the pq listener.
func (p *pgPubsub) listen() {
	defer close(p.listenDone)
	defer p.pgListener.Close()

	var (
		notif *pq.Notification
		ok    bool
	)
	for {
		select {
		case <-p.ctx.Done():
			return
		case notif, ok = <-p.pgListener.Notify:
			if !ok {
				return
			}
		}
		// A nil notification can be dispatched on reconnect.
		if notif == nil {
			p.recordReconnect()
			continue
		}
		p.listenReceive(notif)
	}
}

func (p *pgPubsub) listenReceive(notif *pq.Notification) {
	p.mut.Lock()
	defer p.mut.Unlock()
	queues, ok := p.queues[notif.Channel]
	if !ok {
		return
	}
	extra := []byte(notif.Extra)
	for _, q := range queues {
		q.enqueue(extra)
	}
}

func (p *pgPubsub) recordReconnect() {
	p.mut.Lock()
	defer p.mut.Unlock()
	for _, listeners := range p.queues {
		for _, q := range listeners {
			q.dropped()
		}
	}
}

// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
	// Creates a new listener using pq.
	errCh := make(chan error)
	listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
		// This callback gets events whenever the connection state changes.
		// Don't send if the errChannel has already been closed.
		select {
		case <-errCh:
			return
		default:
			errCh <- err
			close(errCh)
		}
	})
	select {
	case err := <-errCh:
		if err != nil {
			_ = listener.Close()
			return nil, xerrors.Errorf("create pq listener: %w", err)
		}
	case <-ctx.Done():
		_ = listener.Close()
		return nil, ctx.Err()
	}

	// Start a new context that will be canceled when the pubsub is closed.
	ctx, cancel := context.WithCancel(context.Background())
	pgPubsub := &pgPubsub{
		ctx:        ctx,
		cancel:     cancel,
		listenDone: make(chan struct{}),
		db:         database,
		pgListener: listener,
		queues:     make(map[string]map[uuid.UUID]*msgQueue),
	}
	go pgPubsub.listen()
	return pgPubsub, nil
}
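
// examplePubsubUsage is a minimal usage sketch: it assumes the caller already
// holds an open *sql.DB plus the connection URL used to create it, and walks
// the subscribe/publish/close lifecycle of the PostgreSQL-backed Pubsub. The
// event name and payload are hypothetical.
func examplePubsubUsage(ctx context.Context, db *sql.DB, connectURL string) error {
	ps, err := NewPubsub(ctx, db, connectURL)
	if err != nil {
		return err
	}
	defer ps.Close()

	// Subscribe before publishing so the pg_notify broadcast is not missed.
	cancel, err := ps.Subscribe("example-event", func(_ context.Context, message []byte) {
		_ = message // handle the broadcast payload
	})
	if err != nil {
		return err
	}
	defer cancel()

	// Publish delivers the payload via PostgreSQL NOTIFY to every subscriber
	// of the event, including subscribers in other processes sharing the
	// same database.
	return ps.Publish("example-event", []byte("hello"))
}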