mirror of https://github.com/coder/coder.git
fix: fix Listen/Unlisten race on Pubsub (#15315)
Fixes #15312. When we need to `Unlisten()` for an event, instead of immediately removing the event from `p.queues`, we store a channel that we close to signal any goroutines trying to `Subscribe` to the same event that we are done. On `Subscribe`, if the channel is present, we wait for it before calling `Listen`, which ensures the ordering is correct.
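The coordination described above is easier to see in isolation. Below is a minimal, self-contained sketch of the handoff pattern, not the coder/coder implementation: `pubsub`, `eventState`, `subscribe`, and `unsubscribe` are hypothetical names, and the `fmt.Println` calls stand in for the real `pgListener.Listen`/`Unlisten` round trips to PostgreSQL. The last unsubscriber publishes a channel into the shared map before it unlistens; a racing subscriber that finds that channel waits on it before listening, so LISTEN can never reach the server ahead of the UNLISTEN it races with.

package main

import (
	"fmt"
	"sync"
)

// eventState plays the role of queueSet in the real patch: a subscriber count
// plus the channel that an in-flight unlisten publishes and later closes.
type eventState struct {
	subs               int
	unlistenInProgress chan struct{} // non-nil while an UNLISTEN is in flight
}

type pubsub struct {
	mu     sync.Mutex
	events map[string]*eventState
}

func (p *pubsub) subscribe(event string) {
	p.mu.Lock()
	st, ok := p.events[event]
	if !ok {
		st = &eventState{}
		p.events[event] = st
	}
	st.subs++
	wait := st.unlistenInProgress
	p.mu.Unlock()

	// Wait outside the lock: a concurrent unsubscriber must finish its
	// UNLISTEN before our LISTEN, or this subscription would never see events.
	if wait != nil {
		<-wait
	}
	fmt.Println("LISTEN", event) // stand-in for p.pgListener.Listen(event)
}

func (p *pubsub) unsubscribe(event string) {
	p.mu.Lock()
	st := p.events[event] // caller must hold a subscription, so st is non-nil
	st.subs--
	var done chan struct{}
	if st.subs == 0 {
		done = make(chan struct{})
		st.unlistenInProgress = done
	}
	p.mu.Unlock()

	if done == nil {
		return // other subscribers remain; the server must keep listening
	}
	fmt.Println("UNLISTEN", event) // stand-in for p.pgListener.Unlisten(event)
	close(done)                    // release any subscriber waiting to LISTEN

	// Delete the state only if no new subscriber arrived while we unlistened.
	p.mu.Lock()
	if st.subs == 0 {
		delete(p.events, event)
	}
	p.mu.Unlock()
}

func main() {
	p := &pubsub{events: map[string]*eventState{}}
	p.subscribe("workspace_events")

	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); p.unsubscribe("workspace_events") }() // last subscriber leaves
	go func() { defer wg.Done(); p.subscribe("workspace_events") }()  // a new one races in
	wg.Wait()
}

Depending on which goroutine wins the lock, either the count never reaches zero and no UNLISTEN is issued, or UNLISTEN strictly precedes the racing LISTEN; the closed channel makes the inverted order, which is the bug in #15312, impossible.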
@@ -11,7 +11,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/google/uuid"
 	"github.com/lib/pq"
 	"github.com/prometheus/client_golang/prometheus"
 	"golang.org/x/xerrors"
@@ -188,6 +187,19 @@ func (l pqListenerShim) NotifyChan() <-chan *pq.Notification {
 	return l.Notify
 }
 
+type queueSet struct {
+	m map[*msgQueue]struct{}
+	// unlistenInProgress will be non-nil if another goroutine is unlistening for the event this
+	// queueSet corresponds to. If non-nil, that goroutine will close the channel when it is done.
+	unlistenInProgress chan struct{}
+}
+
+func newQueueSet() *queueSet {
+	return &queueSet{
+		m: make(map[*msgQueue]struct{}),
+	}
+}
+
 // PGPubsub is a pubsub implementation using PostgreSQL.
 type PGPubsub struct {
 	logger slog.Logger
@@ -196,7 +208,7 @@ type PGPubsub struct {
 	db     *sql.DB
 
 	qMu    sync.Mutex
-	queues map[string]map[uuid.UUID]*msgQueue
+	queues map[string]*queueSet
 
 	// making the close state its own mutex domain simplifies closing logic so
 	// that we don't have to hold the qMu --- which could block processing
@@ -243,6 +255,48 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 		}
 	}()
 
+	var (
+		unlistenInProgress <-chan struct{}
+		// MUST hold the p.qMu lock to manipulate this!
+		qs *queueSet
+	)
+	func() {
+		p.qMu.Lock()
+		defer p.qMu.Unlock()
+
+		var ok bool
+		if qs, ok = p.queues[event]; !ok {
+			qs = newQueueSet()
+			p.queues[event] = qs
+		}
+		qs.m[newQ] = struct{}{}
+		unlistenInProgress = qs.unlistenInProgress
+	}()
+	// NOTE there cannot be any `return` statements between here and the next +-+, otherwise the
+	// assumptions the defer makes could be violated
+	if unlistenInProgress != nil {
+		// We have to wait here because we don't want our `Listen` call to happen before the other
+		// goroutine calls `Unlisten`. That would result in this subscription not getting any
+		// events. c.f. https://github.com/coder/coder/issues/15312
+		p.logger.Debug(context.Background(), "waiting for Unlisten in progress", slog.F("event", event))
+		<-unlistenInProgress
+		p.logger.Debug(context.Background(), "unlistening complete", slog.F("event", event))
+	}
+	// +-+ (see above)
+	defer func() {
+		if err != nil {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			delete(qs.m, newQ)
+			if len(qs.m) == 0 {
+				// we know that newQ was in the queueSet since we last unlocked, so there cannot
+				// have been any _new_ goroutines trying to Unlisten(). Therefore, if the queueSet
+				// is now empty, it's safe to delete.
+				delete(p.queues, event)
+			}
+		}
+	}()
+
 	// The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches
 	// notifies. We need to avoid holding the mutex while this happens, since holding the mutex
 	// blocks reading notifications and can deadlock the pgListener.
@@ -258,32 +312,40 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(),
 	if err != nil {
 		return nil, xerrors.Errorf("listen: %w", err)
 	}
-	p.qMu.Lock()
-	defer p.qMu.Unlock()
-
-	var eventQs map[uuid.UUID]*msgQueue
-	var ok bool
-	if eventQs, ok = p.queues[event]; !ok {
-		eventQs = make(map[uuid.UUID]*msgQueue)
-		p.queues[event] = eventQs
-	}
-	id := uuid.New()
-	eventQs[id] = newQ
 	return func() {
-		p.qMu.Lock()
-		listeners := p.queues[event]
-		q := listeners[id]
-		q.close()
-		delete(listeners, id)
-		if len(listeners) == 0 {
-			delete(p.queues, event)
-		}
-		listenerCount := len(listeners)
-		p.qMu.Unlock()
-		// as above, we must not hold the lock while calling into pgListener
-
-		if listenerCount == 0 {
+		var unlistening chan struct{}
+		func() {
+			p.qMu.Lock()
+			defer p.qMu.Unlock()
+			newQ.close()
+			qSet, ok := p.queues[event]
+			if !ok {
+				p.logger.Critical(context.Background(), "event was removed before cancel", slog.F("event", event))
+				return
+			}
+			delete(qSet.m, newQ)
+			if len(qSet.m) == 0 {
+				unlistening = make(chan struct{})
+				qSet.unlistenInProgress = unlistening
+			}
+		}()
+
+		// as above, we must not hold the lock while calling into pgListener
+		if unlistening != nil {
 			uErr := p.pgListener.Unlisten(event)
+			close(unlistening)
+			// we can now delete the queueSet if it is empty.
+			func() {
+				p.qMu.Lock()
+				defer p.qMu.Unlock()
+				qSet, ok := p.queues[event]
+				if ok && len(qSet.m) == 0 {
+					p.logger.Debug(context.Background(), "removing queueSet", slog.F("event", event))
+					delete(p.queues, event)
+				}
+			}()
+
 			p.closeMu.Lock()
 			defer p.closeMu.Unlock()
 			if uErr != nil && !p.closedListener {
@@ -361,12 +423,12 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
 
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
-	queues, ok := p.queues[notif.Channel]
+	qSet, ok := p.queues[notif.Channel]
 	if !ok {
 		return
 	}
 	extra := []byte(notif.Extra)
-	for _, q := range queues {
+	for q := range qSet.m {
 		q.enqueue(extra)
 	}
 }
@@ -374,8 +436,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) {
 func (p *PGPubsub) recordReconnect() {
 	p.qMu.Lock()
 	defer p.qMu.Unlock()
-	for _, listeners := range p.queues {
-		for _, q := range listeners {
+	for _, qSet := range p.queues {
+		for q := range qSet.m {
 			q.dropped()
 		}
 	}
@@ -590,8 +652,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
 	p.qMu.Lock()
 	events := len(p.queues)
 	subs := 0
-	for _, subscriberMap := range p.queues {
-		subs += len(subscriberMap)
+	for _, qSet := range p.queues {
+		subs += len(qSet.m)
 	}
 	p.qMu.Unlock()
 	metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs))
@@ -629,7 +691,7 @@ func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
 		logger:          logger,
 		listenDone:      make(chan struct{}),
 		db:              db,
-		queues:          make(map[string]map[uuid.UUID]*msgQueue),
+		queues:          make(map[string]*queueSet),
 		latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
 
 		publishesTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
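For reference, the interleaving reported in #15312 amounts to cancelling the last subscription while concurrently re-subscribing to the same event. A hypothetical test sketch follows, assuming the exported `Subscribe(event string, listener Listener) (cancel func(), err error)` API; `newTestPubsub` and the event name are illustrative and not part of the repository.

package pubsub_test

import (
	"context"
	"testing"

	"github.com/coder/coder/v2/coderd/database/pubsub"
)

// newTestPubsub is a hypothetical helper that would return a pubsub.Pubsub
// backed by a real PostgreSQL connection.
func newTestPubsub(t *testing.T) pubsub.Pubsub {
	t.Helper()
	t.Skip("sketch only: wire up a real database here")
	return nil
}

func TestSubscribeUnlistenRace(t *testing.T) {
	ps := newTestPubsub(t)
	noop := func(_ context.Context, _ []byte) {}

	cancel, err := ps.Subscribe("my_event", noop)
	if err != nil {
		t.Fatal(err)
	}
	// The last subscriber cancels while a new one arrives. Before this commit,
	// the new Subscribe could issue LISTEN before the cancel's UNLISTEN, so the
	// second subscription would be registered locally but never receive events.
	go cancel()
	cancel2, err := ps.Subscribe("my_event", noop)
	if err != nil {
		t.Fatal(err)
	}
	defer cancel2()
}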