mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
fix(coderd/database): improve pubsub closure and context cancellation (#7993)
This commit is contained in:
committed by
GitHub
parent
aba5cb8377
commit
518300a26c
@ -163,6 +163,8 @@ func (q *msgQueue) dropped() {
|
|||||||
// Pubsub implementation using PostgreSQL.
|
// Pubsub implementation using PostgreSQL.
|
||||||
type pgPubsub struct {
|
type pgPubsub struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
listenDone chan struct{}
|
||||||
pgListener *pq.Listener
|
pgListener *pq.Listener
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
mut sync.Mutex
|
mut sync.Mutex
|
||||||
@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
|
|||||||
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
|
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
|
||||||
// support the first parameter being a prepared statement.
|
// support the first parameter being a prepared statement.
|
||||||
//nolint:gosec
|
//nolint:gosec
|
||||||
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
|
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("exec pg_notify: %w", err)
|
return xerrors.Errorf("exec pg_notify: %w", err)
|
||||||
}
|
}
|
||||||
@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
|
|||||||
|
|
||||||
// Close closes the pubsub instance.
|
// Close closes the pubsub instance.
|
||||||
func (p *pgPubsub) Close() error {
|
func (p *pgPubsub) Close() error {
|
||||||
return p.pgListener.Close()
|
p.cancel()
|
||||||
|
err := p.pgListener.Close()
|
||||||
|
<-p.listenDone
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen begins receiving messages on the pq listener.
|
// listen begins receiving messages on the pq listener.
|
||||||
func (p *pgPubsub) listen(ctx context.Context) {
|
func (p *pgPubsub) listen() {
|
||||||
|
defer close(p.listenDone)
|
||||||
|
defer p.pgListener.Close()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
notif *pq.Notification
|
notif *pq.Notification
|
||||||
ok bool
|
ok bool
|
||||||
)
|
)
|
||||||
defer p.pgListener.Close()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-p.ctx.Done():
|
||||||
return
|
return
|
||||||
case notif, ok = <-p.pgListener.Notify:
|
case notif, ok = <-p.pgListener.Notify:
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() {
|
|||||||
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
||||||
// Creates a new listener using pq.
|
// Creates a new listener using pq.
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
|
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
|
||||||
// This callback gets events whenever the connection state changes.
|
// This callback gets events whenever the connection state changes.
|
||||||
// Don't send if the errChannel has already been closed.
|
// Don't send if the errChannel has already been closed.
|
||||||
select {
|
select {
|
||||||
@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
|
|||||||
select {
|
select {
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
_ = listener.Close()
|
||||||
return nil, xerrors.Errorf("create pq listener: %w", err)
|
return nil, xerrors.Errorf("create pq listener: %w", err)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
_ = listener.Close()
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start a new context that will be canceled when the pubsub is closed.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
pgPubsub := &pgPubsub{
|
pgPubsub := &pgPubsub{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
listenDone: make(chan struct{}),
|
||||||
db: database,
|
db: database,
|
||||||
pgListener: listener,
|
pgListener: listener,
|
||||||
queues: make(map[string]map[uuid.UUID]*msgQueue),
|
queues: make(map[string]map[uuid.UUID]*msgQueue),
|
||||||
}
|
}
|
||||||
go pgPubsub.listen(ctx)
|
go pgPubsub.listen()
|
||||||
|
|
||||||
return pgPubsub, nil
|
return pgPubsub, nil
|
||||||
}
|
}
|
||||||
|
@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) {
|
|||||||
event := "test"
|
event := "test"
|
||||||
data := "testing"
|
data := "testing"
|
||||||
messageChannel := make(chan []byte)
|
messageChannel := make(chan []byte)
|
||||||
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||||
messageChannel <- message
|
messageChannel <- message
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer cancelFunc()
|
defer unsub()
|
||||||
go func() {
|
go func() {
|
||||||
err = pubsub.Publish(event, []byte(data))
|
err = pubsub.Publish(event, []byte(data))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) {
|
|||||||
defer pubsub.Close()
|
defer pubsub.Close()
|
||||||
cancelFunc()
|
cancelFunc()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("NotClosedOnCancelContext", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
connectionURL, closePg, err := postgres.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closePg()
|
||||||
|
db, err := sql.Open("postgres", connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pubsub.Close()
|
||||||
|
|
||||||
|
// Provided context must only be active during NewPubsub, not after.
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
event := "test"
|
||||||
|
data := "testing"
|
||||||
|
messageChannel := make(chan []byte)
|
||||||
|
unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
|
||||||
|
messageChannel <- message
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer unsub()
|
||||||
|
go func() {
|
||||||
|
err = pubsub.Publish(event, []byte(data))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
message := <-messageChannel
|
||||||
|
assert.Equal(t, string(message), data)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||||
|
defer cancel()
|
||||||
|
connectionURL, closePg, err := postgres.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closePg()
|
||||||
|
db, err := sql.Open("postgres", connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pubsub.Close()
|
||||||
|
|
||||||
|
event := "test"
|
||||||
|
done := make(chan struct{})
|
||||||
|
called := make(chan struct{})
|
||||||
|
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
|
||||||
|
defer close(done)
|
||||||
|
select {
|
||||||
|
case <-subCtx.Done():
|
||||||
|
assert.Fail(t, "context should not be canceled")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
close(called)
|
||||||
|
select {
|
||||||
|
case <-subCtx.Done():
|
||||||
|
case <-ctx.Done():
|
||||||
|
assert.Fail(t, "timeout waiting for sub context to be canceled")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer unsub()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := pubsub.Publish(event, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-called:
|
||||||
|
case <-ctx.Done():
|
||||||
|
require.Fail(t, "timeout waiting for handler to be called")
|
||||||
|
}
|
||||||
|
err = pubsub.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
require.Fail(t, "timeout waiting for handler to finish")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPubsub_ordering(t *testing.T) {
|
func TestPubsub_ordering(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user