mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
fix: pubsub ordering (#7404)
* fix: pubsub sends messages in order Signed-off-by: Spike Curtis <spike@coder.com> * Drop messages rather than block Signed-off-by: Spike Curtis <spike@coder.com> --------- Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
@ -25,12 +25,17 @@ type Pubsub interface {
|
|||||||
|
|
||||||
// Pubsub implementation using PostgreSQL.
|
// Pubsub implementation using PostgreSQL.
|
||||||
type pgPubsub struct {
|
type pgPubsub struct {
|
||||||
|
ctx context.Context
|
||||||
pgListener *pq.Listener
|
pgListener *pq.Listener
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
mut sync.Mutex
|
mut sync.Mutex
|
||||||
listeners map[string]map[uuid.UUID]Listener
|
listeners map[string]map[uuid.UUID]chan<- []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// messageBufferSize is the maximum number of unhandled messages we will buffer
|
||||||
|
// for a subscriber before dropping messages.
|
||||||
|
const messageBufferSize = 2048
|
||||||
|
|
||||||
// Subscribe calls the listener when an event matching the name is received.
|
// Subscribe calls the listener when an event matching the name is received.
|
||||||
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||||
p.mut.Lock()
|
p.mut.Lock()
|
||||||
@ -45,25 +50,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
|
|||||||
return nil, xerrors.Errorf("listen: %w", err)
|
return nil, xerrors.Errorf("listen: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var eventListeners map[uuid.UUID]Listener
|
var eventListeners map[uuid.UUID]chan<- []byte
|
||||||
var ok bool
|
var ok bool
|
||||||
if eventListeners, ok = p.listeners[event]; !ok {
|
if eventListeners, ok = p.listeners[event]; !ok {
|
||||||
eventListeners = map[uuid.UUID]Listener{}
|
eventListeners = make(map[uuid.UUID]chan<- []byte)
|
||||||
p.listeners[event] = eventListeners
|
p.listeners[event] = eventListeners
|
||||||
}
|
}
|
||||||
|
|
||||||
var id uuid.UUID
|
ctx, cancelCallbacks := context.WithCancel(p.ctx)
|
||||||
for {
|
messages := make(chan []byte, messageBufferSize)
|
||||||
id = uuid.New()
|
go messagesToListener(ctx, messages, listener)
|
||||||
if _, ok = eventListeners[id]; !ok {
|
id := uuid.New()
|
||||||
break
|
eventListeners[id] = messages
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
eventListeners[id] = listener
|
|
||||||
return func() {
|
return func() {
|
||||||
p.mut.Lock()
|
p.mut.Lock()
|
||||||
defer p.mut.Unlock()
|
defer p.mut.Unlock()
|
||||||
|
cancelCallbacks()
|
||||||
listeners := p.listeners[event]
|
listeners := p.listeners[event]
|
||||||
delete(listeners, id)
|
delete(listeners, id)
|
||||||
|
|
||||||
@ -109,11 +111,11 @@ func (p *pgPubsub) listen(ctx context.Context) {
|
|||||||
if notif == nil {
|
if notif == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
p.listenReceive(ctx, notif)
|
p.listenReceive(notif)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
|
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
|
||||||
p.mut.Lock()
|
p.mut.Lock()
|
||||||
defer p.mut.Unlock()
|
defer p.mut.Unlock()
|
||||||
listeners, ok := p.listeners[notif.Channel]
|
listeners, ok := p.listeners[notif.Channel]
|
||||||
@ -122,7 +124,14 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
|
|||||||
}
|
}
|
||||||
extra := []byte(notif.Extra)
|
extra := []byte(notif.Extra)
|
||||||
for _, listener := range listeners {
|
for _, listener := range listeners {
|
||||||
go listener(ctx, extra)
|
select {
|
||||||
|
case listener <- extra:
|
||||||
|
// ok!
|
||||||
|
default:
|
||||||
|
// bad news, we dropped the event because the listener isn't
|
||||||
|
// keeping up
|
||||||
|
// TODO (spike): figure out a way to communicate this to the Listener
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,11 +159,23 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
|
|||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
pgPubsub := &pgPubsub{
|
pgPubsub := &pgPubsub{
|
||||||
|
ctx: ctx,
|
||||||
db: database,
|
db: database,
|
||||||
pgListener: listener,
|
pgListener: listener,
|
||||||
listeners: make(map[string]map[uuid.UUID]Listener),
|
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
|
||||||
}
|
}
|
||||||
go pgPubsub.listen(ctx)
|
go pgPubsub.listen(ctx)
|
||||||
|
|
||||||
return pgPubsub, nil
|
return pgPubsub, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case m := <-messages:
|
||||||
|
listener(ctx, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -5,7 +5,12 @@ package database_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/coder/testutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -67,3 +72,44 @@ func TestPubsub(t *testing.T) {
|
|||||||
cancelFunc()
|
cancelFunc()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPubsub_ordering(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
connectionURL, closePg, err := postgres.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closePg()
|
||||||
|
db, err := sql.Open("postgres", connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pubsub.Close()
|
||||||
|
event := "test"
|
||||||
|
messageChannel := make(chan []byte, 100)
|
||||||
|
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||||
|
// sleep a random amount of time to simulate handlers taking different amount of time
|
||||||
|
// to process, depending on the message
|
||||||
|
// nolint: gosec
|
||||||
|
n := rand.Intn(100)
|
||||||
|
time.Sleep(time.Duration(n) * time.Millisecond)
|
||||||
|
messageChannel <- message
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer cancelFunc()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
select {
|
||||||
|
case <-time.After(testutil.WaitShort):
|
||||||
|
t.Fatalf("timed out waiting for message %d", i)
|
||||||
|
case message := <-messageChannel:
|
||||||
|
assert.Equal(t, fmt.Sprintf("%d", i), string(message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user