feat: pubsub reports dropped messages (#7660)
* Implementation; need linux tests
* Pubsub with errors tests and fixes
* Deal with test goroutines

Signed-off-by: Spike Curtis <spike@coder.com>
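
The user-visible change is the new SubscribeWithErr method and the ErrDroppedMessages sentinel on the Pubsub interface (see the diff below). A subscriber that cannot tolerate silent gaps might use it roughly like this; this is a sketch, not part of the commit, and the event name plus the resync and handle helpers are placeholders:

func watch(ctx context.Context, ps database.Pubsub) (cancel func(), err error) {
	return ps.SubscribeWithErr("example_event", func(ctx context.Context, msg []byte, err error) {
		if xerrors.Is(err, database.ErrDroppedMessages) {
			// Messages were dropped (slow consumer) or may have been dropped
			// (the LISTEN connection reconnected), so the stream cannot be
			// trusted to be complete; fall back to a full refresh.
			resync(ctx)
			return
		}
		handle(ctx, msg)
	})
}

Both Listener and ListenerWithErr subscriptions share the same buffered msgQueue underneath; a plain Listener simply never sees the error entries.
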
@ -22,7 +22,8 @@ import (
// Super unlikely, but it happened. See: https://github.com/coder/coder/runs/5375197003
var openPortMutex sync.Mutex

// Open creates a new PostgreSQL server using a Docker container.
// Open creates a new PostgreSQL database instance. With DB_FROM environment variable set, it clones a database
// from the provided template. With the environment variable unset, it creates a new Docker container running postgres.
func Open() (string, func(), error) {
if os.Getenv("DB_FROM") != "" {
// In CI, creating a Docker container for each test is slow.
@ -51,7 +52,12 @@ func Open() (string, func(), error) {
// so cleaning up the container will clean up the database.
}, nil
}
return OpenContainerized(0)
}

// OpenContainerized creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic
// to that port to the database. If port is zero, allocate a free port from the OS.
func OpenContainerized(port int) (string, func(), error) {
pool, err := dockertest.NewPool("")
if err != nil {
return "", nil, xerrors.Errorf("create pool: %w", err)
@ -63,12 +69,14 @@ func Open() (string, func(), error) {
}

openPortMutex.Lock()
// Pick an explicit port on the host to connect to 5432.
// This is necessary so we can configure the port to only use ipv4.
port, err := getFreePort()
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("get free port: %w", err)
if port == 0 {
// Pick an explicit port on the host to connect to 5432.
// This is necessary so we can configure the port to only use ipv4.
port, err = getFreePort()
if err != nil {
openPortMutex.Unlock()
return "", nil, xerrors.Errorf("get free port: %w", err)
}
}

resource, err := pool.RunWithOptions(&dockertest.RunOptions{
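
The new port parameter exists so a test can kill postgres and bring a replacement up on the same address. The TestPubsub_Disconnect test added later in this commit does exactly that; condensed, the pattern looks like this (a sketch with error handling trimmed to require calls):

	// Let the OS pick a free port for the first container.
	connectionURL, closePg, err := postgres.OpenContainerized(0)
	require.NoError(t, err)

	// ... use the database, then simulate an outage:
	closePg()

	// Recover the original host port from the connection URL and start a
	// fresh container bound to it, so the old URL keeps working.
	u, err := url.Parse(connectionURL)
	require.NoError(t, err)
	addr, err := net.ResolveTCPAddr("tcp", u.Host)
	require.NoError(t, err)
	newURL, closeNewPg, err := postgres.OpenContainerized(addr.Port)
	require.NoError(t, err)
	require.Equal(t, connectionURL, newURL)
	defer closeNewPg()
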
@ -15,29 +15,174 @@ import (
// Listener represents a pubsub handler.
type Listener func(ctx context.Context, message []byte)

// ListenerWithErr represents a pubsub handler that can also receive error
// indications
type ListenerWithErr func(ctx context.Context, message []byte, err error)

// ErrDroppedMessages is sent to ListenerWithErr if messages are dropped or
// might have been dropped.
var ErrDroppedMessages = xerrors.New("dropped messages")

// Pubsub is a generic interface for broadcasting and receiving messages.
// Implementors should assume high-availability with the backing implementation.
type Pubsub interface {
Subscribe(event string, listener Listener) (cancel func(), err error)
SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error)
Publish(event string, message []byte) error
Close() error
}

// msgOrErr either contains a message or an error
type msgOrErr struct {
msg []byte
err error
}

// msgQueue implements a fixed length queue with the ability to replace elements
// after they are queued (but before they are dequeued).
//
// The purpose of this data structure is to build something that works a bit
// like a golang channel, but if the queue is full, then we can replace the
// last element with an error so that the subscriber can get notified that some
// messages were dropped, all without blocking.
type msgQueue struct {
ctx context.Context
cond *sync.Cond
q [PubsubBufferSize]msgOrErr
front int
size int
closed bool
l Listener
le ListenerWithErr
}

func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue {
if l == nil && le == nil {
panic("l or le must be non-nil")
}
q := &msgQueue{
ctx: ctx,
cond: sync.NewCond(&sync.Mutex{}),
l: l,
le: le,
}
go q.run()
return q
}

func (q *msgQueue) run() {
for {
// wait until there is something on the queue or we are closed
q.cond.L.Lock()
for q.size == 0 && !q.closed {
q.cond.Wait()
}
if q.closed {
q.cond.L.Unlock()
return
}
item := q.q[q.front]
q.front = (q.front + 1) % PubsubBufferSize
q.size--
q.cond.L.Unlock()

// process item without holding lock
if item.err == nil {
// real message
if q.l != nil {
q.l(q.ctx, item.msg)
continue
}
if q.le != nil {
q.le(q.ctx, item.msg, nil)
continue
}
// unhittable
continue
}
// if the listener wants errors, send it.
if q.le != nil {
q.le(q.ctx, nil, item.err)
}
}
}

func (q *msgQueue) enqueue(msg []byte) {
q.cond.L.Lock()
defer q.cond.L.Unlock()

if q.size == PubsubBufferSize {
// queue is full, so we're going to drop the msg we got called with.
// We also need to record that messages are being dropped, which we
// do at the last message in the queue. This potentially makes us
// lose 2 messages instead of one, but it's more important at this
// point to warn the subscriber that they're losing messages so they
// can do something about it.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the message
next := (q.front + q.size) % PubsubBufferSize
q.q[next].msg = msg
q.q[next].err = nil
q.size++
q.cond.Broadcast()
}

func (q *msgQueue) close() {
q.cond.L.Lock()
defer q.cond.L.Unlock()
defer q.cond.Broadcast()
q.closed = true
}

// dropped records an error in the queue that messages might have been dropped
func (q *msgQueue) dropped() {
q.cond.L.Lock()
defer q.cond.L.Unlock()

if q.size == PubsubBufferSize {
// queue is full, but we need to record that messages are being dropped,
// which we do at the last message in the queue. This potentially drops
// another message, but it's more important for the subscriber to know.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the error
next := (q.front + q.size) % PubsubBufferSize
q.q[next].msg = nil
q.q[next].err = ErrDroppedMessages
q.size++
q.cond.Broadcast()
}

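
To make the index arithmetic in enqueue and dropped concrete, here is a toy, standalone calculation using a hypothetical capacity of 4 instead of PubsubBufferSize (illustration only, not part of the commit):

package main

import "fmt"

func main() {
	const capacity = 4
	front, size := 1, capacity // a full ring occupying slots 1, 2, 3, 0

	// Slot holding the newest queued element; on overflow it is overwritten
	// with ErrDroppedMessages rather than blocking the publisher.
	back := (front + capacity - 1) % capacity

	// Slot the next message would occupy if the ring were not full.
	next := (front + size) % capacity

	fmt.Println(back, next) // prints: 0 1
}

So a full queue sacrifices its most recent entry, plus the message that triggered the overflow, in exchange for a guaranteed drop notification to the subscriber.
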
// Pubsub implementation using PostgreSQL.
type pgPubsub struct {
ctx context.Context
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[uuid.UUID]chan<- []byte
queues map[string]map[uuid.UUID]*msgQueue
}

// messageBufferSize is the maximum number of unhandled messages we will buffer
// PubsubBufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const messageBufferSize = 2048
const PubsubBufferSize = 2048

// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, listener, nil))
}

func (p *pgPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return p.subscribeQueue(event, newMsgQueue(p.ctx, nil, listener))
}

func (p *pgPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) {
p.mut.Lock()
defer p.mut.Unlock()

@ -50,23 +195,20 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err)
}

var eventListeners map[uuid.UUID]chan<- []byte
var eventQs map[uuid.UUID]*msgQueue
var ok bool
if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = make(map[uuid.UUID]chan<- []byte)
p.listeners[event] = eventListeners
if eventQs, ok = p.queues[event]; !ok {
eventQs = make(map[uuid.UUID]*msgQueue)
p.queues[event] = eventQs
}

ctx, cancelCallbacks := context.WithCancel(p.ctx)
messages := make(chan []byte, messageBufferSize)
go messagesToListener(ctx, messages, listener)
id := uuid.New()
eventListeners[id] = messages
eventQs[id] = newQ
return func() {
p.mut.Lock()
defer p.mut.Unlock()
cancelCallbacks()
listeners := p.listeners[event]
listeners := p.queues[event]
q := listeners[id]
q.close()
delete(listeners, id)

if len(listeners) == 0 {
@ -109,6 +251,7 @@ func (p *pgPubsub) listen(ctx context.Context) {
}
// A nil notification can be dispatched on reconnect.
if notif == nil {
p.recordReconnect()
continue
}
p.listenReceive(notif)
@ -118,19 +261,22 @@ func (p *pgPubsub) listen(ctx context.Context) {
func (p *pgPubsub) listenReceive(notif *pq.Notification) {
p.mut.Lock()
defer p.mut.Unlock()
listeners, ok := p.listeners[notif.Channel]
queues, ok := p.queues[notif.Channel]
if !ok {
return
}
extra := []byte(notif.Extra)
for _, listener := range listeners {
select {
case listener <- extra:
// ok!
default:
// bad news, we dropped the event because the listener isn't
// keeping up
// TODO (spike): figure out a way to communicate this to the Listener
for _, q := range queues {
q.enqueue(extra)
}
}

func (p *pgPubsub) recordReconnect() {
p.mut.Lock()
defer p.mut.Unlock()
for _, listeners := range p.queues {
for _, q := range listeners {
q.dropped()
}
}
}
@ -162,20 +308,9 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
ctx: ctx,
db: database,
pgListener: listener,
listeners: make(map[string]map[uuid.UUID]chan<- []byte),
queues: make(map[string]map[uuid.UUID]*msgQueue),
}
go pgPubsub.listen(ctx)

return pgPubsub, nil
}

func messagesToListener(ctx context.Context, messages <-chan []byte, listener Listener) {
for {
select {
case <-ctx.Done():
return
case m := <-messages:
listener(ctx, m)
}
}
}
coderd/database/pubsub_internal_test.go (new file, 140 lines)
@ -0,0 +1,140 @@
package database

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/coder/coder/testutil"
)

func Test_msgQueue_ListenerWithError(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
e := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
m <- string(msg)
e <- err
})
defer uut.close()

// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.NoError(t, err)
}
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, "", msg)
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-e:
require.ErrorIs(t, err, ErrDroppedMessages)
}
}
}

func Test_msgQueue_Listener(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
m := make(chan string)
uut := newMsgQueue(ctx, func(ctx context.Context, msg []byte) {
m <- string(msg)
}, nil)
defer uut.close()

// We're going to enqueue 4 messages and an error in a loop -- that is, a cycle of 5.
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
}
uut.dropped()
for i := 0; i < 4; i++ {
select {
case <-ctx.Done():
t.Fatal("timed out")
case msg := <-m:
require.Equal(t, fmt.Sprintf("%d%d", j, i), msg)
}
}
// Listener skips over errors, so we only read out the 4 real messages.
}
}

func Test_msgQueue_Full(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()

firstDequeue := make(chan struct{})
allowRead := make(chan struct{})
n := 0
errors := make(chan error)
uut := newMsgQueue(ctx, nil, func(ctx context.Context, msg []byte, err error) {
if n == 0 {
close(firstDequeue)
}
<-allowRead
if err == nil {
require.Equal(t, fmt.Sprintf("%d", n), string(msg))
n++
return
}
errors <- err
})
defer uut.close()

// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
// but only after we've dequeued a message, and then another extra because we want to exceed
// the capacity, not just reach it.
for i := 0; i < PubsubBufferSize+2; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d", i)))
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
// against the goroutine that dequeues items.
<-firstDequeue
}
close(allowRead)

select {
case <-ctx.Done():
t.Fatal("timed out")
case err := <-errors:
require.ErrorIs(t, err, ErrDroppedMessages)
}
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
// for the error, so we read 2 less than we sent.
require.Equal(t, PubsubBufferSize, n)
}
@ -7,20 +7,43 @@ import (
"github.com/google/uuid"
)

// genericListener is either a Listener or ListenerWithErr
type genericListener struct {
l Listener
le ListenerWithErr
}

func (g genericListener) send(ctx context.Context, message []byte) {
if g.l != nil {
g.l(ctx, message)
}
if g.le != nil {
g.le(ctx, message, nil)
}
}

// memoryPubsub is an in-memory Pubsub implementation.
type memoryPubsub struct {
mut sync.RWMutex
listeners map[string]map[uuid.UUID]Listener
listeners map[string]map[uuid.UUID]genericListener
}

func (m *memoryPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
return m.subscribeGeneric(event, genericListener{l: listener})
}

func (m *memoryPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) {
return m.subscribeGeneric(event, genericListener{le: listener})
}

func (m *memoryPubsub) subscribeGeneric(event string, listener genericListener) (cancel func(), err error) {
m.mut.Lock()
defer m.mut.Unlock()

var listeners map[uuid.UUID]Listener
var listeners map[uuid.UUID]genericListener
var ok bool
if listeners, ok = m.listeners[event]; !ok {
listeners = map[uuid.UUID]Listener{}
listeners = map[uuid.UUID]genericListener{}
m.listeners[event] = listeners
}
var id uuid.UUID
@ -52,7 +75,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
listener := listener
go func() {
defer wg.Done()
listener(context.Background(), message)
listener.send(context.Background(), message)
}()
}
wg.Wait()
@ -66,6 +89,6 @@ func (*memoryPubsub) Close() error {

func NewPubsubInMemory() Pubsub {
return &memoryPubsub{
listeners: make(map[string]map[uuid.UUID]Listener),
listeners: make(map[string]map[uuid.UUID]genericListener),
}
}
@ -13,7 +13,7 @@ import (
func TestPubsubMemory(t *testing.T) {
t.Parallel()

t.Run("Memory", func(t *testing.T) {
t.Run("Legacy", func(t *testing.T) {
t.Parallel()

pubsub := database.NewPubsubInMemory()
@ -32,4 +32,25 @@ func TestPubsubMemory(t *testing.T) {
message := <-messageChannel
assert.Equal(t, string(message), data)
})

t.Run("WithErr", func(t *testing.T) {
t.Parallel()

pubsub := database.NewPubsubInMemory()
event := "test"
data := "testing"
messageChannel := make(chan []byte)
cancelFunc, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, message []byte, err error) {
assert.NoError(t, err) // memory pubsub never sends errors.
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
go func() {
err = pubsub.Publish(event, []byte(data))
assert.NoError(t, err)
}()
message := <-messageChannel
assert.Equal(t, string(message), data)
})
}
@ -7,16 +7,19 @@ import (
"database/sql"
"fmt"
"math/rand"
"net"
"net/url"
"strconv"
"testing"
"time"

"github.com/coder/coder/testutil"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/testutil"
)

// nolint:tparallel,paralleltest
@ -90,7 +93,7 @@ func TestPubsub_ordering(t *testing.T) {
defer pubsub.Close()
event := "test"
messageChannel := make(chan []byte, 100)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
// sleep a random amount of time to simulate handlers taking different amount of time
// to process, depending on the message
// nolint: gosec
@ -99,7 +102,7 @@ func TestPubsub_ordering(t *testing.T) {
messageChannel <- message
})
require.NoError(t, err)
defer cancelFunc()
defer cancelSub()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
assert.NoError(t, err)
@ -113,3 +116,143 @@ func TestPubsub_ordering(t *testing.T) {
}
}
}

func TestPubsub_Disconnect(t *testing.T) {
t.Parallel()
// we always use a Docker container for this test, even in CI, since we need to be able to kill
// postgres and bring it back on the same port.
connectionURL, closePg, err := postgres.OpenContainerized(0)
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()

ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"

// buffer responses so that when the test completes, goroutines don't get blocked & leak
errors := make(chan error, database.PubsubBufferSize)
messages := make(chan string, database.PubsubBufferSize)
readOne := func() (m string, e error) {
t.Helper()
select {
case <-ctx.Done():
t.Fatal("timed out")
case m = <-messages:
// OK
}
select {
case <-ctx.Done():
t.Fatal("timed out")
case e = <-errors:
// OK
}
return m, e
}

cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
messages <- string(msg)
errors <- err
})
require.NoError(t, err)
defer cancelSub()

for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
require.NoError(t, err)
}
// make sure we're getting at least one message.
m, err := readOne()
require.NoError(t, err)
require.Equal(t, "0", m)

closePg()
// write some more messages until we hit an error
j := 100
for {
select {
case <-ctx.Done():
t.Fatal("timed out")
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
if err != nil {
break
}
time.Sleep(testutil.IntervalFast)
}

// restart postgres on the same port --- since we only use LISTEN/NOTIFY it doesn't
// matter that the new postgres doesn't have any persisted state from before.
u, err := url.Parse(connectionURL)
require.NoError(t, err)
addr, err := net.ResolveTCPAddr("tcp", u.Host)
require.NoError(t, err)
newURL, closeNewPg, err := postgres.OpenContainerized(addr.Port)
require.NoError(t, err)
require.Equal(t, connectionURL, newURL)
defer closeNewPg()

// now write messages until we DON'T hit an error -- pubsub is back up.
for {
select {
case <-ctx.Done():
t.Fatal("timed out")
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
if err == nil {
break
}
j++
time.Sleep(testutil.IntervalFast)
}
// any message k or higher comes from after the restart.
k := j
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB
// reconnect
require.Less(t, k, database.PubsubBufferSize, "exceeded buffer")

// We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As
// soon as we see k or higher we know we're getting messages after the restart.
go func() {
for {
select {
case <-ctx.Done():
return
default:
// ok
}
_ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
time.Sleep(testutil.IntervalFast)
}
}()

gotDroppedErr := false
for {
m, err := readOne()
if xerrors.Is(err, database.ErrDroppedMessages) {
gotDroppedErr = true
continue
}
require.NoError(t, err, "should only get ErrDroppedMessages")
l, err := strconv.Atoi(m)
require.NoError(t, err)
if l >= k {
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than
// DB reconnect
require.Less(t, l, database.PubsubBufferSize, "exceeded buffer")
break
}
}
require.True(t, gotDroppedErr)
}