fix: fix Listen/Unlisten race on Pubsub (#15315)

Fixes #15312

When we need to `Unlisten()` for an event, instead of immediately removing the event from `p.queues`, we store a channel that we close once the `Unlisten()` is done, signaling any goroutines trying to `Subscribe` to the same event. On `Subscribe`, if the channel is present, we wait for it before calling `Listen`, which ensures the two calls are ordered correctly.
Spike Curtis
2024-11-01 14:35:26 +04:00
committed by GitHub
parent fbbefa228d
commit 005ea536a5
2 changed files with 155 additions and 31 deletions
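
To make the coordination easier to follow, here is a minimal, self-contained sketch of the pattern the commit message describes. It is not the actual `PGPubsub` implementation: the `listener` interface, `pubsub` struct, and channel-keyed subscriber set below are simplified stand-ins invented for illustration.

```go
// Package pubsubsketch illustrates the Unlisten/Subscribe coordination
// described in the commit message; the types here are stand-ins, not the
// real PGPubsub types.
package pubsubsketch

import "sync"

// listener stands in for the underlying LISTEN/UNLISTEN client.
type listener interface {
	Listen(event string) error
	Unlisten(event string) error
}

type queueSet struct {
	subscribers map[chan []byte]struct{}
	// unlistenInProgress is non-nil while some goroutine is calling Unlisten
	// for this event; that goroutine closes it when the call returns.
	unlistenInProgress chan struct{}
}

type pubsub struct {
	mu     sync.Mutex
	l      listener
	queues map[string]*queueSet
}

func (p *pubsub) subscribe(event string, q chan []byte) (func(), error) {
	p.mu.Lock()
	qs, ok := p.queues[event]
	if !ok {
		qs = &queueSet{subscribers: make(map[chan []byte]struct{})}
		p.queues[event] = qs
	}
	qs.subscribers[q] = struct{}{}
	wait := qs.unlistenInProgress // snapshot under the lock
	p.mu.Unlock()

	if wait != nil {
		// An Unlisten for this event is in flight; wait for it so our Listen
		// is ordered after it and we don't silently miss notifications.
		<-wait
	}
	if err := p.l.Listen(event); err != nil {
		p.mu.Lock()
		delete(qs.subscribers, q)
		if len(qs.subscribers) == 0 {
			delete(p.queues, event)
		}
		p.mu.Unlock()
		return nil, err
	}

	// cancel removes the subscriber and, if it was the last one, unlistens.
	return func() {
		var unlistening chan struct{}
		p.mu.Lock()
		delete(qs.subscribers, q)
		if len(qs.subscribers) == 0 {
			// Last subscriber: announce the upcoming Unlisten so that any
			// concurrent Subscribe waits for it before calling Listen.
			unlistening = make(chan struct{})
			qs.unlistenInProgress = unlistening
		}
		p.mu.Unlock()

		if unlistening != nil {
			_ = p.l.Unlisten(event) // must not hold the lock while calling out
			close(unlistening)      // release any waiting Subscribe
			p.mu.Lock()
			if cur, ok := p.queues[event]; ok && len(cur.subscribers) == 0 {
				delete(p.queues, event)
			}
			p.mu.Unlock()
		}
	}, nil
}
```

The ordering guarantee comes from registering the subscriber and snapshotting `unlistenInProgress` under a single lock acquisition: either the subscriber sees the in-flight `Unlisten` and waits for its channel to close before calling `Listen`, or the cancelling goroutine sees the new subscriber and skips the `Unlisten` entirely.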


@@ -11,7 +11,6 @@ import (
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
@@ -188,6 +187,19 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
return l.Notify
}
type queueSet struct {
m map[*msgQueue]struct{}
// unlistenInProgress will be non-nil if another goroutine is unlistening for the event this
// queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done.
unlistenInProgress chan struct{}
}
func newQueueSet() *queueSet {
return &queueSet{
m: make(map[*msgQueue]struct{}),
}
}
// PGPubsub is a pubsub implementation using PostgreSQL.
type PGPubsub struct {
logger slog.Logger
@@ -196,7 +208,7 @@ type PGPubsub struct {
db *sql.DB
qMu sync.Mutex
queues map[string]map[uuid.UUID]*msgQueue
queues map[string]*queueSet
// making the close state its own mutex domain simplifies closing logic so
// that we don't have to hold the qMu --- which could block processing
@@ -243,6 +255,48 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
}
}()
var (
unlistenInProgress <-chan struct{}
// MUST hold the p.qMu lock to manipulate this!
qs *queueSet
)
func() {
p.qMu.Lock()
defer p.qMu.Unlock()
var ok bool
if qs, ok = p.queues[event]; !ok {
qs = newQueueSet()
p.queues[event] = qs
}
qs.m[newQ] = struct{}{}
unlistenInProgress = qs.unlistenInProgress
}()
// NOTE there cannot be any `return` statements between here and the next +-+, otherwise the
// assumptions the defer makes could be violated
if unlistenInProgress != nil {
// We have to wait here because we don't want our `Listen` call to happen before the other
// goroutine calls `Unlisten`. That would result in this subscription not getting any
// events. c.f. https://github.com/coder/coder/issues/15312
p.logger.Debug(context.Background(), "waiting for Unlisten in progress", slog.F("event", event))
<-unlistenInProgress
p.logger.Debug(context.Background(), "unlistening complete", slog.F("event", event))
}
// +-+ (see above)
defer func() {
if err != nil {
p.qMu.Lock()
defer p.qMu.Unlock()
delete(qs.m, newQ)
if len(qs.m) == 0 {
// we know that newQ was in the queueSet since we last unlocked, so there cannot
// have been any _new_ goroutines trying to Unlisten(). Therefore, if the queueSet
// is now empty, it's safe to delete.
delete(p.queues, event)
}
}
}()
// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
// blocks reading notifications and can deadlock the pgListener.
@@ -258,32 +312,40 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
if err != nil {
return nil, xerrors.Errorf("listen: %w", err)
}
p.qMu.Lock()
defer p.qMu.Unlock()
var eventQs map[uuid.UUID]*msgQueue
var ok bool
if eventQs, ok = p.queues[event]; !ok {
eventQs = make(map[uuid.UUID]*msgQueue)
p.queues[event] = eventQs
}
id := uuid.New()
eventQs[id] = newQ
return func() {
p.qMu.Lock()
listeners := p.queues[event]
q := listeners[id]
q.close()
delete(listeners, id)
if len(listeners) == 0 {
delete(p.queues, event)
}
listenerCount := len(listeners)
p.qMu.Unlock()
// as above, we must not hold the lock while calling into pgListener
var unlistening chan struct{}
func() {
p.qMu.Lock()
defer p.qMu.Unlock()
newQ.close()
qSet, ok := p.queues[event]
if !ok {
p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event))
return
}
delete(qSet.m, newQ)
if len(qSet.m) == 0 {
unlistening = make(chan struct{})
qSet.unlistenInProgress = unlistening
}
}()
if listenerCount == 0 {
// as above, we must not hold the lock while calling into pgListener
if unlistening != nil {
uErr := p.pgListener.Unlisten(event)
close(unlistening)
// we can now delete the queueSet if it is empty.
func() {
p.qMu.Lock()
defer p.qMu.Unlock()
qSet, ok := p.queues[event]
if ok && len(qSet.m) == 0 {
p.logger.Debug(context.Background(), "removing queueSet", slog.F("event", event))
delete(p.queues, event)
}
}()
p.closeMu.Lock()
defer p.closeMu.Unlock()
if uErr != nil && !p.closedListener {
@@ -361,12 +423,12 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
p.qMu.Lock()
defer p.qMu.Unlock()
queues, ok := p.queues[notif.Channel]
qSet, ok := p.queues[notif.Channel]
if !ok {
return
}
extra := []byte(notif.Extra)
for _, q := range queues {
for q := range qSet.m {
q.enqueue(extra)
}
}
@@ -374,8 +436,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
func (p *PGPubsub) recordReconnect() {
p.qMu.Lock()
defer p.qMu.Unlock()
for _, listeners := range p.queues {
for _, q := range listeners {
for _, qSet := range p.queues {
for q := range qSet.m {
q.dropped()
}
}
@@ -590,8 +652,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
p.qMu.Lock()
events := len(p.queues)
subs := 0
for _, subscriberMap := range p.queues {
subs += len(subscriberMap)
for _, qSet := range p.queues {
subs += len(qSet.m)
}
p.qMu.Unlock()
metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
@@ -629,7 +691,7 @@ func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
logger: logger,
listenDone: make(chan struct{}),
db: db,
queues: make(map[string]map[uuid.UUID]*msgQueue),
queues: make(map[string]*queueSet),
latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{