fix: pubsub ordering (#7404)

* fix: pubsub sends messages in order

Signed-off-by: Spike Curtis <spike@coder.com>

* Drop messages rather than block

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis
2023-05-05 09:39:07 +04:00
committed by GitHub
parent 667d9a7557
commit a6a44896bf
2 changed files with 83 additions and 16 deletions

View File

@ -25,12 +25,17 @@ type Pubsub interface {
// Pubsub implementation using PostgreSQL. // Pubsub implementation using PostgreSQL.
type pgPubsub struct { type pgPubsub struct {
ctx context.Context
pgListener *pq.Listener pgListener *pq.Listener
db *sql.DB db *sql.DB
mut sync.Mutex mut sync.Mutex
listeners map[string]map[uuid.UUID]Listener listeners map[string]map[uuid.UUID]chan<- []byte
} }
// messageBufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const messageBufferSize = 2048
// Subscribe calls the listener when an event matching the name is received. // Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
p.mut.Lock() p.mut.Lock()
@ -45,25 +50,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err) return nil, xerrors.Errorf("listen: %w", err)
} }
var eventListeners map[uuid.UUID]Listener var eventListeners map[uuid.UUID]chan<- []byte
var ok bool var ok bool
if eventListeners, ok = p.listeners[event]; !ok { if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = map[uuid.UUID]Listener{} eventListeners = make(map[uuid.UUID]chan<- []byte)
p.listeners[event] = eventListeners p.listeners[event] = eventListeners
} }
var id uuid.UUID ctx, cancelCallbacks := context.WithCancel(p.ctx)
for { messages := make(chan []byte, messageBufferSize)
id = uuid.New() go messagesToListener(ctx, messages, listener)
if _, ok = eventListeners[id]; !ok { id := uuid.New()
break eventListeners[id] = messages
}
}
eventListeners[id] = listener
return func() { return func() {
p.mut.Lock() p.mut.Lock()
defer p.mut.Unlock() defer p.mut.Unlock()
cancelCallbacks()
listeners := p.listeners[event] listeners := p.listeners[event]
delete(listeners, id) delete(listeners, id)
@ -109,11 +111,11 @@ func (p *pgPubsub) listen(ctx context.Context) {
if notif == nil { if notif == nil {
continue continue
} }
p.listenReceive(ctx, notif) p.listenReceive(notif)
} }
} }
func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) { func (p *pgPubsub) listenReceive(notif *pq.Notification) {
p.mut.Lock() p.mut.Lock()
defer p.mut.Unlock() defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel] listeners, ok := p.listeners[notif.Channel]
@ -122,7 +124,14 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
} }
extra := []byte(notif.Extra) extra := []byte(notif.Extra)
for _, listener := range listeners { for _, listener := range listeners {
go listener(ctx, extra) select {
case listener <- extra:
// ok!
default:
// bad news, we dropped the event because the listener isn't
// keeping up
// TODO (spike): figure out a way to communicate this to the Listener
}
} }
} }
@ -150,11 +159,23 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
return nil, ctx.Err() return nil, ctx.Err()
} }
pgPubsub := &pgPubsub{ pgPubsub := &pgPubsub{
ctx: ctx,
db: database, db: database,
pgListener: listener, pgListener: listener,
listeners: make(map[string]map[uuid.UUID]Listener), listeners: make(map[string]map[uuid.UUID]chan<- []byte),
} }
go pgPubsub.listen(ctx) go pgPubsub.listen(ctx)
return pgPubsub, nil return pgPubsub, nil
} }
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
for {
select {
case <-ctx.Done():
return
case m := <-messages:
listener(ctx, m)
}
}
}

View File

@ -5,7 +5,12 @@ package database_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"math/rand"
"testing" "testing"
"time"
"github.com/coder/coder/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -67,3 +72,44 @@ func TestPubsub(t *testing.T) {
cancelFunc() cancelFunc()
}) })
} }
func TestPubsub_ordering(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
messageChannel := make(chan []byte, 100)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
// sleep a random amount of time to simulate handlers taking different amount of time
// to process, depending on the message
// nolint: gosec
n := rand.Intn(100)
time.Sleep(time.Duration(n) * time.Millisecond)
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
assert.NoError(t, err)
}
for i := 0; i < 100; i++ {
select {
case <-time.After(testutil.WaitShort):
t.Fatalf("timed out waiting for message %d", i)
case message := <-messageChannel:
assert.Equal(t, fmt.Sprintf("%d", i), string(message))
}
}
}