feat: pubsub reports dropped messages (#7660)

* Implementation; need linux tests

Signed-off-by: Spike Curtis <spike@coder.com>

* Pubsub with errors tests and fixes

Signed-off-by: Spike Curtis <spike@coder.com>

* Deal with test goroutines

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
Commit 67cc196c92 (parent 6a1e7ee1d0)
Author: Spike Curtis, committed by GitHub
Date: 2023-05-25 10:22:30 +04:00
6 changed files with 522 additions and 52 deletions


@@ -22,7 +22,8 @@ import (
// Super unlikely, but it happened. See: https://github.com/coder/coder/runs/5375197003
var openPortMutex sync.Mutex
// Open creates a new PostgreSQL server using a Docker container.
// Open creates a new PostgreSQL database instance. With DB_FROM environment variable set, it clones a database
// from the provided template. With the environment variable unset, it creates a new Docker container running postgres.
func Open() (string, func(), error) {
if os.Getenv("DB_FROM") != "" {
// In CI, creating a Docker container for each test is slow.
@@ -51,7 +52,12 @@ func Open() (string, func(), error) {
// so cleaning up the container will clean up the database.
}, nil
}
return OpenContainerized(0)
}
// OpenContainerized creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic
// to that port to the database. If port is zero, allocate a free port from the OS.
func OpenContainerized(port int) (string, func(), error) {
pool, err := dockertest.NewPool("")
if err != nil {
return "", nil, xerrors.Errorf("create pool: %w", err)
@@ -63,12 +69,14 @@ func Open() (string, func(), error) {
}
openPortMutex.Lock()
// Pick an explicit port on the host to connect to 5432.
// This is necessary so we can configure the port to only use ipv4.
port, err := getFreePort()
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("get free port: %w", err)
if port == 0 {
// Pick an explicit port on the host to connect to 5432.
// This is necessary so we can configure the port to only use ipv4.
port, err = getFreePort()
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("get free port: %w", err)
}
}
resource, err := pool.RunWithOptions(&dockertest.RunOptions{

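For context on the Open/OpenContainerized split above, here is a minimal sketch (not part of this commit) of how a test might call the new helper; the test name is invented, the other identifiers match the diff. Passing 0 asks the OS for a free host port, while a nonzero port pins the container to that port, which is what lets TestPubsub_Disconnect further down restart postgres at the same address.

```go
package example_test

import (
	"database/sql"
	"testing"

	_ "github.com/lib/pq" // postgres driver, as used elsewhere in the repo
	"github.com/stretchr/testify/require"

	"github.com/coder/coder/coderd/database/postgres"
)

func TestWithContainerizedPostgres(t *testing.T) {
	t.Parallel()

	// 0 = let the OS allocate a free host port for the container.
	connectionURL, closePg, err := postgres.OpenContainerized(0)
	require.NoError(t, err)
	defer closePg()

	db, err := sql.Open("postgres", connectionURL)
	require.NoError(t, err)
	defer db.Close()
	require.NoError(t, db.Ping())
}
```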

@@ -15,29 +15,174 @@ import (
// Listener represents a pubsub handler.
type Listener func(ctx context.Context, message []byte)
// ListenerWithErr represents a pubsub handler that can also receive error
// indications
type ListenerWithErr func(ctx context.Context, message []byte, err error)
// ErrDroppedMessages is sent to ListenerWithErr if messages are dropped or
// might have been dropped.
var ErrDroppedMessages = xerrors.New("dropped messages")
// Pubsub is a generic interface for broadcasting and receiving messages.
// Implementors should assume high-availability with the backing implementation.
type Pubsub interface {
Subscribe(event string, listener Listener) (cancel func(), err error)
SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error)
Publish(event string, message []byte) error
Close() error
}
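A minimal usage sketch of the new error-aware subscription, from a caller's point of view (the event name and the resync/handle callbacks are hypothetical; Pubsub and ErrDroppedMessages are the identifiers defined in this file):

```go
package example

import (
	"context"
	"errors"

	"github.com/coder/coder/coderd/database"
)

// watchEvents is illustrative only. On ErrDroppedMessages the subscriber falls
// back to re-reading state from the database instead of trusting the stream.
func watchEvents(ps database.Pubsub, resync func(context.Context), handle func(context.Context, []byte)) (cancel func(), err error) {
	return ps.SubscribeWithErr("example_event", func(ctx context.Context, msg []byte, err error) {
		if errors.Is(err, database.ErrDroppedMessages) {
			// Some notifications may have been missed; recover out of band.
			resync(ctx)
			return
		}
		handle(ctx, msg)
	})
}
```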
// msgOrErr either contains a message or an error
type msgOrErr struct {
msg []byte
err error
}
// msgQueue implements a fixed length queue with the ability to replace elements
// after they are queued (but before they are dequeued).
//
// The purpose of this data structure is to build something that works a bit
// like a golang channel, but if the queue is full, then we can replace the
// last element with an error so that the subscriber can get notified that some
// messages were dropped, all without blocking.
type msgQueue struct {
ctx context.Context
cond *sync.Cond
q [PubsubBufferSize]msgOrErr
front int
size int
closed bool
l Listener
le ListenerWithErr
}
func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue {
if l == nil && le == nil {
panic("l or le must be non-nil")
}
q := &msgQueue{
ctx: ctx,
cond: sync.NewCond(&sync.Mutex{}),
l: l,
le: le,
}
go q.run()
return q
}
func (q *msgQueue) run() {
for {
// wait until there is something on the queue or we are closed
q.cond.L.Lock()
for q.size == 0 && !q.closed {
q.cond.Wait()
}
if q.closed {
q.cond.L.Unlock()
return
}
item := q.q[q.front]
q.front = (q.front + 1) % PubsubBufferSize
q.size--
q.cond.L.Unlock()
// process item without holding lock
if item.err == nil {
// real message
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
// unhittable
continue
}
// if the listener wants errors, send it.
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
}
}
func (q *msgQueue) enqueue(msg []byte) {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.size == PubsubBufferSize {
// queue is full, so we're going to drop the msg we got called with.
// We also need to record that messages are being dropped, which we
// do at the last message in the queue. This potentially makes us
// lose 2 messages instead of one, but it's more important at this
// point to warn the subscriber that they're losing messages so they
// can do something about it.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the message
next := (q.front + q.size) % PubsubBufferSize
q.q[next].msg = msg
q.q[next].err = nil
q.size++
q.cond.Broadcast()
}
func (q *msgQueue) close() {
q.cond.L.Lock()
defer q.cond.L.Unlock()
defer q.cond.Broadcast()
q.closed = true
}
// dropped records an error in the queue that messages might have been dropped
func (q *msgQueue) dropped() {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.size == PubsubBufferSize {
// queue is full, but we need to record that messages are being dropped,
// which we do at the last message in the queue. This potentially drops
// another message, but it's more important for the subscriber to know.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the error
next := (q.front + q.size) % PubsubBufferSize
q.q[next].msg = nil
q.q[next].err = ErrDroppedMessages
q.size++
q.cond.Broadcast()
}
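enqueue and dropped share the same ring arithmetic: the next free slot is (front + size) % PubsubBufferSize, and when the queue is full the slot holding the newest element is (front + PubsubBufferSize - 1) % PubsubBufferSize. A toy check with a hypothetical capacity of 4 shows the wrap-around working when front sits near the end of the array:

```go
package main

import "fmt"

func main() {
	const capacity = 4  // stand-in for PubsubBufferSize
	front, size := 3, 4 // full ring: oldest element in slot 3, newest in slot 2

	next := (front + size) % capacity         // 3 — wraps back onto front, so there is no free slot
	back := (front + capacity - 1) % capacity // 2 — the newest element, overwritten with ErrDroppedMessages

	fmt.Println(next, back) // 3 2
}
```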
// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
ctx context.Context
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[uuid.UUID]chan<- []byte
queues map[string]map[uuid.UUID]*msgQueue
}
// messageBufferSize is the maximum number of unhandled messages we will buffer
// PubsubBufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const messageBufferSize = 2048
const PubsubBufferSize = 2048
// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
}
func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
}
func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()
@@ -50,23 +195,20 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err)
}
var eventListeners map[uuid.UUID]chan<- []byte
var eventQs map[uuid.UUID]*msgQueue
var ok bool
if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = make(map[uuid.UUID]chan<- []byte)
p.listeners[event] = eventListeners
if eventQs, ok = p.queues[event]; !ok {
eventQs = make(map[uuid.UUID]*msgQueue)
p.queues[event] = eventQs
}
ctx, cancelCallbacks := context.WithCancel(p.ctx)
messages := make(chan []byte, messageBufferSize)
go messagesToListener(ctx, messages, listener)
id := uuid.New()
eventListeners[id] = messages
eventQs[id] = newQ
return func() {
p.mut.Lock()
defer p.mut.Unlock()
cancelCallbacks()
listeners := p.listeners[event]
listeners := p.queues[event]
q := listeners[id]
q.close()
delete(listeners, id)
if len(listeners) == 0 {
@@ -109,6 +251,7 @@ func (p *pgPubsub) listen(ctx context.Context) {
}
// A nil notification can be dispatched on reconnect.
if notif == nil {
p.recordReconnect()
continue
}
p.listenReceive(notif)
@@ -118,19 +261,22 @@ func (p *pgPubsub) listen(ctx context.Context) {
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
p.mut.Lock()
defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel]
queues, ok := p.queues[notif.Channel]
if !ok {
return
}
extra := []byte(notif.Extra)
for _, listener := range listeners {
select {
case listener <- extra:
// ok!
default:
// bad news, we dropped the event because the listener isn't
// keeping up
// TODO (spike): figure out a way to communicate this to the Listener
for _, q := range queues {
q.enqueue(extra)
}
}
func (p *pgPubsub) recordReconnect() {
p.mut.Lock()
defer p.mut.Unlock()
for _, listeners := range p.queues {
for _, q := range listeners {
q.dropped()
}
}
}
@@ -162,20 +308,9 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
ctx: ctx,
db: database,
pgListener: listener,
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
queues: make(map[string]map[uuid.UUID]*msgQueue),
}
go pgPubsub.listen(ctx)
return pgPubsub, nil
}
func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
for {
select {
case <-ctx.Done():
return
case m := <-messages:
listener(ctx, m)
}
}
}


@@ -0,0 +1,140 @@
package database
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/testutil"
)
func Test_msgQueue_ListenerWithError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
e := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
m <- string(msg)
e <- err
})
defer uut.close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.NoError(t, err)
}
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, "", msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.ErrorIs(t, err, ErrDroppedMessages)
}
}
}
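The "will not be aligned" claim in the test comment is just the observation that 5 does not divide 2048, so each lap of the ring starts the 4-messages-plus-error pattern at a different offset; a quick sanity check:

```go
package main

import "fmt"

func main() {
	// 2048 (PubsubBufferSize) is not a multiple of the 5-item test pattern, so
	// wrap-around lands the pattern on different ring offsets on each lap.
	fmt.Println(2048 % 5) // 3
}
```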
func Test_msgQueue_Listener(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) {
m <- string(msg)
}, nil)
defer uut.close()
// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
}
// Listener skips over errors, so we only read out the 4 real messages.
}
}
func Test_msgQueue_Full(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
firstDequeue := make(chan struct{})
allowRead := make(chan struct{})
n := 0
errors := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
if n == 0 {
close(firstDequeue)
}
<-allowRead
if err == nil {
require.Equal(t, fmt.Sprintf("%d", n), string(msg))
n++
return
}
errors <- err
})
defer uut.close()
// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
// but only after we've dequeued a message, and then another extra because we want to exceed
// the capacity, not just reach it.
for i := 0; i < PubsubBufferSize+2; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d", i)))
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
// against the goroutine that dequeues items.
<-firstDequeue
}
close(allowRead)
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-errors:
require.ErrorIs(t, err, ErrDroppedMessages)
}
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
// for the error, so we read 2 less than we sent.
require.Equal(t, PubsubBufferSize, n)
}
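To make that closing comment concrete, the expected tally (assuming the listener blocks immediately after dequeuing message "0", as the firstDequeue/allowRead channels arrange) is sketched below: 2050 messages are enqueued, "2048" is overwritten in place by ErrDroppedMessages, and "2049" is never queued, leaving exactly PubsubBufferSize real deliveries.

```go
package main

import "fmt"

func main() {
	const bufferSize = 2048 // PubsubBufferSize
	sent := bufferSize + 2  // "0" .. "2049"
	dropped := 2            // "2048" replaced by ErrDroppedMessages, "2049" never queued
	delivered := sent - dropped

	fmt.Println(delivered == bufferSize) // true — matches require.Equal(t, PubsubBufferSize, n)
}
```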


@@ -7,20 +7,43 @@ import (
"github.com/google/uuid"
)
// genericListener is either a Listener or ListenerWithErr
type genericListener struct {
l Listener
le ListenerWithErr
}
func (g genericListener) send(ctx context.Context, message []byte) {
if g.l != nil {
g.l(ctx, message)
}
if g.le != nil {
g.le(ctx, message, nil)
}
}
// memoryPubsub is an in-memory Pubsub implementation.
type memoryPubsub struct {
mut sync.RWMutex
listeners map[string]map[uuid.UUID]Listener
listeners map[string]map[uuid.UUID]genericListener
}
func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return m.subscribeGeneric(event, genericListener{l: listener})
}
func (m *memoryPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return m.subscribeGeneric(event, genericListener{le: listener})
}
func (m *memoryPubsub) subscribeGeneric(event string, listener genericListener) (cancel func(), err error) {
m.mut.Lock()
defer m.mut.Unlock()
var listeners map[uuid.UUID]Listener
var listeners map[uuid.UUID]genericListener
var ok bool
if listeners, ok = m.listeners[event]; !ok {
listeners = map[uuid.UUID]Listener{}
listeners = map[uuid.UUID]genericListener{}
m.listeners[event] = listeners
}
var id uuid.UUID
@@ -52,7 +75,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
listener := listener
go func() {
defer wg.Done()
listener(context.Background(), message)
listener.send(context.Background(), message)
}()
}
wg.Wait()
@@ -66,6 +89,6 @@ func (*memoryPubsub) Close() error {
func NewPubsubInMemory() Pubsub {
return &memoryPubsub{
listeners: make(map[string]map[uuid.UUID]Listener),
listeners: make(map[string]map[uuid.UUID]genericListener),
}
}


@@ -13,7 +13,7 @@ import (
func TestPubsubMemory(t *testing.T) {
t.Parallel()
t.Run("Memory", func(t *testing.T) {
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
pubsub := database.NewPubsubInMemory()
@@ -32,4 +32,25 @@ func TestPubsubMemory(t *testing.T) {
message := <-messageChannel
assert.Equal(t, string(message), data)
})
t.Run("WithErr", func(t *testing.T) {
t.Parallel()
pubsub := database.NewPubsubInMemory()
event := "test"
data := "testing"
messageChannel := make(chan []byte)
cancelFunc, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, message []byte, err error) {
assert.NoError(t, err) // memory pubsub never sends errors.
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
assert.NoError(t, err)
}()
message := <-messageChannel
assert.Equal(t, string(message), data)
})
}


@@ -7,16 +7,19 @@ import (
"database/sql"
"fmt"
"math/rand"
"net"
"net/url"
"strconv"
"testing"
"time"
"github.com/coder/coder/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/testutil"
)
// nolint:tparallel,paralleltest
@@ -90,7 +93,7 @@ func TestPubsub_ordering(t *testing.T) {
defer pubsub.Close()
event := "test"
messageChannel := make(chan []byte, 100)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
// sleep a random amount of time to simulate handlers taking different amount of time
// to process, depending on the message
// nolint: gosec
@@ -99,7 +102,7 @@ func TestPubsub_ordering(t *testing.T) {
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
defer cancelSub()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
assert.NoError(t, err)
@@ -113,3 +116,143 @@ func TestPubsub_ordering(t *testing.T) {
}
}
}
func TestPubsub_Disconnect(t *testing.T) {
t.Parallel()
// we always use a Docker container for this test, even in CI, since we need to be able to kill
// postgres and bring it back on the same port.
connectionURL, closePg, err := postgres.OpenContainerized(0)
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
// buffer responses so that when the test completes, goroutines don't get blocked & leak
errors := make(chan error, database.PubsubBufferSize)
messages := make(chan string, database.PubsubBufferSize)
readOne := func() (m string, e error) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timed out")
case m = <-messages:
// OK
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case e = <-errors:
// OK
}
return m, e
}
cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
messages <- string(msg)
errors <- err
})
require.NoError(t, err)
defer cancelSub()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
require.NoError(t, err)
}
// make sure we're getting at least one message.
m, err := readOne()
require.NoError(t, err)
require.Equal(t, "0", m)
closePg()
// write some more messages until we hit an error
j := 100
for {
select {
case <-ctx.Done():
t.Fatal("timed out")
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
if err != nil {
break
}
time.Sleep(testutil.IntervalFast)
}
// restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't
// matter that the new postgres doesn't have any persisted state from before.
u, err := url.Parse(connectionURL)
require.NoError(t, err)
addr, err := net.ResolveTCPAddr("tcp", u.Host)
require.NoError(t, err)
newURL, closeNewPg, err := postgres.OpenContainerized(addr.Port)
require.NoError(t, err)
require.Equal(t, connectionURL, newURL)
defer closeNewPg()
// now write messages until we DON'T hit an error -- pubsub is back up.
for {
select {
case <-ctx.Done():
t.Fatal("timed out")
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
if err == nil {
break
}
j++
time.Sleep(testutil.IntervalFast)
}
// any message k or higher comes from after the restart.
k := j
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB
// reconnect
require.Less(t, k, database.PubsubBufferSize, "exceeded buffer")
// We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As
// soon as we see k or higher we know we're getting messages after the restart.
go func() {
for {
select {
case <-ctx.Done():
return
default:
// ok
}
_ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
time.Sleep(testutil.IntervalFast)
}
}()
gotDroppedErr := false
for {
m, err := readOne()
if xerrors.Is(err, database.ErrDroppedMessages) {
gotDroppedErr = true
continue
}
require.NoError(t, err, "should only get ErrDroppedMessages")
l, err := strconv.Atoi(m)
require.NoError(t, err)
if l >= k {
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than
// DB reconnect
require.Less(t, l, database.PubsubBufferSize, "exceeded buffer")
break
}
}
require.True(t, gotDroppedErr)
}