fix(coderd/database): improve pubsub closure and context cancellation (#7993)

This commit is contained in:
Mathias Fredriksson
2023-06-13 15:19:56 +03:00
committed by GitHub
parent aba5cb8377
commit 518300a26c
2 changed files with 108 additions and 9 deletions

View File

@ -163,6 +163,8 @@ func (q *msgQueue) dropped() {
// Pubsub implementation using PostgreSQL. // Pubsub implementation using PostgreSQL.
type pgPubsub struct { type pgPubsub struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
listenDone chan struct{}
pgListener *pq.Listener pgListener *pq.Listener
db *sql.DB db *sql.DB
mut sync.Mutex mut sync.Mutex
@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't // This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
// support the first parameter being a prepared statement. // support the first parameter being a prepared statement.
//nolint:gosec //nolint:gosec
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) _, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil { if err != nil {
return xerrors.Errorf("exec pg_notify: %w", err) return xerrors.Errorf("exec pg_notify: %w", err)
} }
@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
// Close closes the pubsub instance. // Close closes the pubsub instance.
// (Side-by-side diff: the left column is the pre-change code, the right column
// is the post-change code.)
// The post-change version cancels the pubsub's own context, closes the pq
// listener, and then blocks until listen() has closed p.listenDone before
// returning the error from the listener's Close.
// NOTE(review): listen() also defers p.pgListener.Close(), so after p.cancel()
// the listener can be closed from two goroutines; confirm that whichever Close
// runs second cannot surface a spurious error to the caller here.
func (p *pgPubsub) Close() error { func (p *pgPubsub) Close() error {
return p.pgListener.Close() p.cancel()
err := p.pgListener.Close()
<-p.listenDone
return err
} }
// listen begins receiving messages on the pq listener. // listen begins receiving messages on the pq listener.
func (p *pgPubsub) listen(ctx context.Context) { func (p *pgPubsub) listen() {
defer close(p.listenDone)
defer p.pgListener.Close()
var ( var (
notif *pq.Notification notif *pq.Notification
ok bool ok bool
) )
defer p.pgListener.Close()
for { for {
select { select {
case <-ctx.Done(): case <-p.ctx.Done():
return return
case notif, ok = <-p.pgListener.Notify: case notif, ok = <-p.pgListener.Notify:
if !ok { if !ok {
@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() {
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) { func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq. // Creates a new listener using pq.
errCh := make(chan error) errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) { listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes. // This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed. // Don't send if the errChannel has already been closed.
select { select {
@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
select { select {
case err := <-errCh: case err := <-errCh:
if err != nil { if err != nil {
_ = listener.Close()
return nil, xerrors.Errorf("create pq listener: %w", err) return nil, xerrors.Errorf("create pq listener: %w", err)
} }
case <-ctx.Done(): case <-ctx.Done():
_ = listener.Close()
return nil, ctx.Err() return nil, ctx.Err()
} }
// Start a new context that will be canceled when the pubsub is closed.
ctx, cancel := context.WithCancel(context.Background())
pgPubsub := &pgPubsub{ pgPubsub := &pgPubsub{
ctx: ctx, ctx: ctx,
cancel: cancel,
listenDone: make(chan struct{}),
db: database, db: database,
pgListener: listener, pgListener: listener,
queues: make(map[string]map[uuid.UUID]*msgQueue), queues: make(map[string]map[uuid.UUID]*msgQueue),
} }
go pgPubsub.listen(ctx) go pgPubsub.listen()
return pgPubsub, nil return pgPubsub, nil
} }

View File

@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) {
event := "test" event := "test"
data := "testing" data := "testing"
messageChannel := make(chan []byte) messageChannel := make(chan []byte)
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) { unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
messageChannel <- message messageChannel <- message
}) })
require.NoError(t, err) require.NoError(t, err)
defer cancelFunc() defer unsub()
go func() { go func() {
err = pubsub.Publish(event, []byte(data)) err = pubsub.Publish(event, []byte(data))
assert.NoError(t, err) assert.NoError(t, err)
@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) {
defer pubsub.Close() defer pubsub.Close()
cancelFunc() cancelFunc()
}) })
// NotClosedOnCancelContext verifies that the context handed to NewPubsub only
// governs construction: canceling it afterwards must not stop the pubsub from
// delivering published messages to subscribers.
t.Run("NotClosedOnCancelContext", func(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	connectionURL, closePg, err := postgres.Open()
	require.NoError(t, err)
	defer closePg()
	db, err := sql.Open("postgres", connectionURL)
	require.NoError(t, err)
	defer db.Close()
	pubsub, err := database.NewPubsub(ctx, db, connectionURL)
	require.NoError(t, err)
	defer pubsub.Close()

	// Provided context must only be active during NewPubsub, not after.
	cancel()

	// ctx is canceled on purpose above, so waiting is bounded by a separate
	// context. It is deferred after pubsub.Close so it is released first.
	waitCtx, waitCancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer waitCancel()

	event := "test"
	data := "testing"
	messageChannel := make(chan []byte)
	unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
		// Don't block the handler forever if the test has already given up.
		select {
		case messageChannel <- message:
		case <-waitCtx.Done():
		}
	})
	require.NoError(t, err)
	defer unsub()

	go func() {
		// Use a goroutine-local err rather than writing to the variable
		// shared with the test goroutine.
		err := pubsub.Publish(event, []byte(data))
		assert.NoError(t, err)
	}()

	// Fail instead of hanging if the message is never delivered.
	select {
	case message := <-messageChannel:
		// testify expects (expected, actual).
		assert.Equal(t, data, string(message))
	case <-waitCtx.Done():
		require.Fail(t, "timeout waiting for message")
	}
})
// ClosePropagatesContextCancellationToSubscription checks that the context
// passed to a subscription handler is not canceled when the handler is
// invoked, and expects it to become canceled after pubsub.Close() is called.
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
// ctx is passed to NewPubsub and also bounds every wait in this subtest.
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
connectionURL, closePg, err := postgres.Open()
require.NoError(t, err)
defer closePg()
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
require.NoError(t, err)
// NOTE(review): Close is also called explicitly below, so this deferred call
// is a second Close on the same instance; its returned error is discarded.
defer pubsub.Close()
event := "test"
// done is closed when the handler returns; called is closed once the handler
// has passed its initial "context not canceled" check.
done := make(chan struct{})
called := make(chan struct{})
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
defer close(done)
// Non-blocking check: the subscription context must still be live on entry.
select {
case <-subCtx.Done():
assert.Fail(t, "context should not be canceled")
default:
}
// Signal the test goroutine that the handler is running.
close(called)
// Block until the subscription context is canceled, or give up when the
// test timeout expires.
select {
case <-subCtx.Done():
case <-ctx.Done():
assert.Fail(t, "timeout waiting for sub context to be canceled")
}
})
require.NoError(t, err)
defer unsub()
// Publish an empty payload from a separate goroutine to trigger the handler.
go func() {
err := pubsub.Publish(event, nil)
assert.NoError(t, err)
}()
// Wait until the handler has started before closing the pubsub.
select {
case <-called:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to be called")
}
// Closing is expected to cancel subCtx, which lets the handler return.
err = pubsub.Close()
require.NoError(t, err)
// The handler must finish (closing done) before the test timeout.
select {
case <-done:
case <-ctx.Done():
require.Fail(t, "timeout waiting for handler to finish")
}
})
} }
func TestPubsub_ordering(t *testing.T) { func TestPubsub_ordering(t *testing.T) {