chore: separate pubsub into a new package (#8017)

* chore: rename store to dbmock for consistency

* chore: remove redundant dbtype package

This wasn't necessary and forked how we do DB types.

* chore: separate pubsub into a new package

This didn't need to be in database and was bloating it.
Authored by Kyle Carberry on 2023-06-14 10:34:54 -05:00; committed by GitHub
parent 2c843f4011, commit e4b6f5695b
17 changed files with 95 additions and 85 deletions
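
For callers the change is mechanical: the pubsub interface, constructors, buffer-size constant, and dropped-messages error move from coderd/database into the new coderd/database/pubsub package, and the Pubsub prefix is dropped from their names (database.Pubsub becomes pubsub.Pubsub, database.NewPubsub becomes pubsub.New, database.NewPubsubInMemory becomes pubsub.NewInMemory, database.PubsubBufferSize becomes pubsub.BufferSize, database.ErrDroppedMessages becomes pubsub.ErrDroppedMessages). Below is a minimal sketch of the new call sites, mirroring the server-command wiring in the first diff; the helper name and the nil-DB check are illustrative only, not part of this commit:

package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/coder/coder/coderd/database/pubsub"
)

// newPubsub picks the in-memory implementation when no database is
// configured and the PostgreSQL-backed implementation otherwise.
func newPubsub(ctx context.Context, sqlDB *sql.DB, connectionURL string) (pubsub.Pubsub, error) {
	if sqlDB == nil {
		// Previously database.NewPubsubInMemory().
		return pubsub.NewInMemory(), nil
	}
	// Previously database.NewPubsub(ctx, sqlDB, connectionURL).
	return pubsub.New(ctx, sqlDB, connectionURL)
}

func main() {
	ps, err := newPubsub(context.Background(), nil, "")
	if err != nil {
		log.Fatal(err)
	}
	defer ps.Close()
}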

View File

@ -68,6 +68,7 @@ import (
"github.com/coder/coder/coderd/database/dbmetrics"
"github.com/coder/coder/coderd/database/dbpurge"
"github.com/coder/coder/coderd/database/migrations"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/devtunnel"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
@ -463,7 +464,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Logger: logger.Named("coderd"),
Database: dbfake.New(),
DERPMap: derpMap,
Pubsub: database.NewPubsubInMemory(),
Pubsub: pubsub.NewInMemory(),
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
GitAuthConfigs: gitAuthConfigs,
@ -589,7 +590,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if cfg.InMemoryDatabase {
// This is only used for testing.
options.Database = dbmetrics.New(dbfake.New(), options.PrometheusRegistry)
options.Pubsub = database.NewPubsubInMemory()
options.Pubsub = pubsub.NewInMemory()
} else {
sqlDB, err := connectToPostgres(ctx, logger, sqlDriver, cfg.PostgresURL.String())
if err != nil {
@ -600,7 +601,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}()
options.Database = dbmetrics.New(database.New(sqlDB), options.PrometheusRegistry)
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, cfg.PostgresURL.String())
options.Pubsub, err = pubsub.New(ctx, sqlDB, cfg.PostgresURL.String())
if err != nil {
return xerrors.Errorf("create pubsub: %w", err)
}

View File

@ -48,6 +48,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbmetrics"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/healthcheck"
@ -95,7 +96,7 @@ type Options struct {
AppHostnameRegex *regexp.Regexp
Logger slog.Logger
Database database.Store
Pubsub database.Pubsub
Pubsub pubsub.Pubsub
// CacheDir is used for caching files served by the API.
CacheDir string

View File

@ -59,6 +59,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/healthcheck"
@ -130,7 +131,7 @@ type Options struct {
// It should only be used in cases where multiple Coder
// test instances are running against the same database.
Database database.Store
Pubsub database.Pubsub
Pubsub pubsub.Pubsub
ConfigSSH codersdk.SSHConfigResponse

View File

@ -11,13 +11,14 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/database/pubsub"
)
func NewDB(t testing.TB) (database.Store, database.Pubsub) {
func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) {
t.Helper()
db := dbfake.New()
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
if os.Getenv("DB") != "" {
connectionURL := os.Getenv("CODER_PG_CONNECTION_URL")
if connectionURL == "" {
@ -36,12 +37,12 @@ func NewDB(t testing.TB) (database.Store, database.Pubsub) {
})
db = database.New(sqlDB)
pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL)
ps, err = pubsub.New(context.Background(), sqlDB, connectionURL)
require.NoError(t, err)
t.Cleanup(func() {
_ = pubsub.Close()
_ = ps.Close()
})
}
return db, pubsub
return db, ps
}

View File

@ -1,4 +1,4 @@
package database
package pubsub
import (
"context"
@ -48,7 +48,7 @@ type msgOrErr struct {
type msgQueue struct {
ctx context.Context
cond *sync.Cond
q [PubsubBufferSize]msgOrErr
q [BufferSize]msgOrErr
front int
size int
closed bool
@ -82,7 +82,7 @@ func (q *msgQueue) run() {
return
}
item := q.q[q.front]
q.front = (q.front + 1) % PubsubBufferSize
q.front = (q.front + 1) % BufferSize
q.size--
q.cond.L.Unlock()
@ -111,20 +111,20 @@ func (q *msgQueue) enqueue(msg []byte) {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.size == PubsubBufferSize {
if q.size == BufferSize {
// queue is full, so we're going to drop the msg we got called with.
// We also need to record that messages are being dropped, which we
// do at the last message in the queue. This potentially makes us
// lose 2 messages instead of one, but it's more important at this
// point to warn the subscriber that they're losing messages so they
// can do something about it.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
back := (q.front + BufferSize - 1) % BufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the message
next := (q.front + q.size) % PubsubBufferSize
next := (q.front + q.size) % BufferSize
q.q[next].msg = msg
q.q[next].err = nil
q.size++
@ -143,17 +143,17 @@ func (q *msgQueue) dropped() {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.size == PubsubBufferSize {
if q.size == BufferSize {
// queue is full, but we need to record that messages are being dropped,
// which we do at the last message in the queue. This potentially drops
// another message, but it's more important for the subscriber to know.
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
back := (q.front + BufferSize - 1) % BufferSize
q.q[back].msg = nil
q.q[back].err = ErrDroppedMessages
return
}
// queue is not full, insert the error
next := (q.front + q.size) % PubsubBufferSize
next := (q.front + q.size) % BufferSize
q.q[next].msg = nil
q.q[next].err = ErrDroppedMessages
q.size++
@ -171,9 +171,9 @@ type pgPubsub struct {
queues map[string]map[uuid.UUID]*msgQueue
}
// PubsubBufferSize is the maximum number of unhandled messages we will buffer
// BufferSize is the maximum number of unhandled messages we will buffer
// for a subscriber before dropping messages.
const PubsubBufferSize = 2048
const BufferSize = 2048
// Subscribe calls the listener when an event matching the name is received.
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
@ -295,8 +295,8 @@ func (p *pgPubsub) recordReconnect() {
}
}
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// New creates a new Pubsub implementation using a PostgreSQL connection.
func New(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {

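The buffering above is what subscribers observe through their listener callbacks: each subscription gets a ring of BufferSize messages, and on overflow the incoming message is discarded while the last queued entry is overwritten with ErrDroppedMessages so the subscriber knows it fell behind. A small usage sketch against the in-memory implementation, which exposes the same interface; the event name and payload are illustrative:

package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/coder/coder/coderd/database/pubsub"
)

func main() {
	ps := pubsub.NewInMemory()
	defer ps.Close()

	msgs := make(chan string, 1)
	// SubscribeWithErr also surfaces ErrDroppedMessages on overflow; a plain
	// Subscribe listener only ever sees message payloads.
	cancel, err := ps.SubscribeWithErr("example-event", func(_ context.Context, msg []byte, err error) {
		if errors.Is(err, pubsub.ErrDroppedMessages) {
			log.Println("subscriber fell behind; some messages were dropped")
			return
		}
		msgs <- string(msg)
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cancel()

	if err := ps.Publish("example-event", []byte("hello")); err != nil {
		log.Fatal(err)
	}
	fmt.Println("received:", <-msgs)
}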
View File

@ -1,4 +1,4 @@
package database
package pubsub
import (
"context"
@ -26,7 +26,7 @@ func Test_msgQueue_ListenerWithError(t *testing.T) {
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
cycles := (BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
@ -75,7 +75,7 @@ func Test_msgQueue_Listener(t *testing.T) {
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
// when we wrap around the end of the circular buffer. This tests that we correctly handle
// the wrapping and aren't dequeueing misaligned data.
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
cycles := (BufferSize / 5) * 2 // almost twice around the ring
for j := 0; j < cycles; j++ {
for i := 0; i < 4; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
@ -119,7 +119,7 @@ func Test_msgQueue_Full(t *testing.T) {
// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
// but only after we've dequeued a message, and then another extra because we want to exceed
// the capacity, not just reach it.
for i := 0; i < PubsubBufferSize+2; i++ {
for i := 0; i < BufferSize+2; i++ {
uut.enqueue([]byte(fmt.Sprintf("%d", i)))
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
// against the goroutine that dequeues items.
@ -136,5 +136,5 @@ func Test_msgQueue_Full(t *testing.T) {
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
// for the error, so we read 2 less than we sent.
require.Equal(t, PubsubBufferSize, n)
require.Equal(t, BufferSize, n)
}

View File

@ -1,4 +1,4 @@
package database
package pubsub
import (
"context"
@ -87,7 +87,7 @@ func (*memoryPubsub) Close() error {
return nil
}
func NewPubsubInMemory() Pubsub {
func NewInMemory() Pubsub {
return &memoryPubsub{
listeners: make(map[string]map[uuid.UUID]genericListener),
}

View File

@ -1,4 +1,4 @@
package database_test
package pubsub_test
import (
"context"
@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/pubsub"
)
func TestPubsubMemory(t *testing.T) {
@ -16,7 +16,7 @@ func TestPubsubMemory(t *testing.T) {
t.Run("Legacy", func(t *testing.T) {
t.Parallel()
pubsub := database.NewPubsubInMemory()
pubsub := pubsub.NewInMemory()
event := "test"
data := "testing"
messageChannel := make(chan []byte)
@ -36,7 +36,7 @@ func TestPubsubMemory(t *testing.T) {
t.Run("WithErr", func(t *testing.T) {
t.Parallel()
pubsub := database.NewPubsubInMemory()
pubsub := pubsub.NewInMemory()
event := "test"
data := "testing"
messageChannel := make(chan []byte)

View File

@ -1,6 +1,6 @@
//go:build linux
package database_test
package pubsub_test
import (
"context"
@ -15,8 +15,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/testutil"
)
@ -39,7 +39,7 @@ func TestPubsub(t *testing.T) {
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
pubsub, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
event := "test"
@ -67,7 +67,7 @@ func TestPubsub(t *testing.T) {
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
pubsub, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
cancelFunc()
@ -82,7 +82,7 @@ func TestPubsub(t *testing.T) {
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
pubsub, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
@ -114,7 +114,7 @@ func TestPubsub(t *testing.T) {
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
pubsub, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
@ -171,12 +171,12 @@ func TestPubsub_ordering(t *testing.T) {
db, err := sql.Open("postgres", connectionURL)
require.NoError(t, err)
defer db.Close()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
ps, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
defer ps.Close()
event := "test"
messageChannel := make(chan []byte, 100)
cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
cancelSub, err := ps.Subscribe(event, func(ctx context.Context, message []byte) {
// sleep a random amount of time to simulate handlers taking different amount of time
// to process, depending on the message
// nolint: gosec
@ -187,7 +187,7 @@ func TestPubsub_ordering(t *testing.T) {
require.NoError(t, err)
defer cancelSub()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
err = ps.Publish(event, []byte(fmt.Sprintf("%d", i)))
assert.NoError(t, err)
}
for i := 0; i < 100; i++ {
@ -219,14 +219,14 @@ func TestPubsub_Disconnect(t *testing.T) {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
ps, err := pubsub.New(ctx, db, connectionURL)
require.NoError(t, err)
defer pubsub.Close()
defer ps.Close()
event := "test"
// buffer responses so that when the test completes, goroutines don't get blocked & leak
errors := make(chan error, database.PubsubBufferSize)
messages := make(chan string, database.PubsubBufferSize)
errors := make(chan error, pubsub.BufferSize)
messages := make(chan string, pubsub.BufferSize)
readOne := func() (m string, e error) {
t.Helper()
select {
@ -244,7 +244,7 @@ func TestPubsub_Disconnect(t *testing.T) {
return m, e
}
cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
cancelSub, err := ps.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
messages <- string(msg)
errors <- err
})
@ -252,7 +252,7 @@ func TestPubsub_Disconnect(t *testing.T) {
defer cancelSub()
for i := 0; i < 100; i++ {
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
err = ps.Publish(event, []byte(fmt.Sprintf("%d", i)))
require.NoError(t, err)
}
// make sure we're getting at least one message.
@ -270,7 +270,7 @@ func TestPubsub_Disconnect(t *testing.T) {
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
err = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
if err != nil {
break
@ -292,7 +292,7 @@ func TestPubsub_Disconnect(t *testing.T) {
default:
// ok
}
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
err = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
if err == nil {
break
}
@ -303,7 +303,7 @@ func TestPubsub_Disconnect(t *testing.T) {
k := j
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB
// reconnect
require.Less(t, k, database.PubsubBufferSize, "exceeded buffer")
require.Less(t, k, pubsub.BufferSize, "exceeded buffer")
// We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As
// soon as we see k or higher we know we're getting messages after the restart.
@ -315,7 +315,7 @@ func TestPubsub_Disconnect(t *testing.T) {
default:
// ok
}
_ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
_ = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
j++
time.Sleep(testutil.IntervalFast)
}
@ -324,7 +324,7 @@ func TestPubsub_Disconnect(t *testing.T) {
gotDroppedErr := false
for {
m, err := readOne()
if xerrors.Is(err, database.ErrDroppedMessages) {
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
gotDroppedErr = true
continue
}
@ -334,7 +334,7 @@ func TestPubsub_Disconnect(t *testing.T) {
if l >= k {
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than
// DB reconnect
require.Less(t, l, database.PubsubBufferSize, "exceeded buffer")
require.Less(t, l, pubsub.BufferSize, "exceeded buffer")
break
}
}

View File

@ -31,6 +31,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/schedule"
@ -56,7 +57,7 @@ type Server struct {
GitAuthConfigs []*gitauth.Config
Tags json.RawMessage
Database database.Store
Pubsub database.Pubsub
Pubsub pubsub.Pubsub
Telemetry telemetry.Reporter
Tracer trace.Tracer
QuotaCommitter *atomic.Pointer[proto.QuotaCommitter]

View File

@ -21,6 +21,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbgen"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/schedule"
@ -51,14 +52,14 @@ func TestAcquireJob(t *testing.T) {
t.Run("Debounce", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
srv := &provisionerdserver.Server{
ID: uuid.New(),
Logger: slogtest.Make(t, nil),
AccessURL: &url.URL{},
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Database: db,
Pubsub: pubsub,
Pubsub: ps,
Telemetry: telemetry.NewNoop(),
AcquireJobDebounce: time.Hour,
Auditor: mockAuditor(),
@ -1256,7 +1257,7 @@ func TestInsertWorkspaceResource(t *testing.T) {
func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
t.Helper()
db := dbfake.New()
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
return &provisionerdserver.Server{
ID: uuid.New(),
@ -1265,7 +1266,7 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
AccessURL: &url.URL{},
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
Database: db,
Pubsub: pubsub,
Pubsub: ps,
Telemetry: telemetry.NewNoop(),
Auditor: mockAuditor(),
TemplateScheduleStore: testTemplateScheduleStore(),

View File

@ -19,6 +19,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/db2sdk"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
@ -268,7 +269,7 @@ type logFollower struct {
ctx context.Context
logger slog.Logger
db database.Store
pubsub database.Pubsub
pubsub pubsub.Pubsub
r *http.Request
rw http.ResponseWriter
conn *websocket.Conn
@ -281,14 +282,14 @@ type logFollower struct {
}
func newLogFollower(
ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub,
ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub,
rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
) *logFollower {
return &logFollower{
ctx: ctx,
logger: logger,
db: db,
pubsub: pubsub,
pubsub: ps,
r: r,
rw: rw,
jobID: job.ID,

View File

@ -20,6 +20,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbmock"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
@ -138,7 +139,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) {
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
now := database.Now()
job := database.ProvisionerJob{
ID: uuid.New(),
@ -157,7 +158,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) {
// we need an HTTP server to get a websocket
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10)
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 10)
uut.follow()
}))
defer srv.Close()
@ -200,7 +201,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
now := database.Now()
job := database.ProvisionerJob{
ID: uuid.New(),
@ -217,7 +218,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
// we need an HTTP server to get a websocket
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0)
uut.follow()
}))
defer srv.Close()
@ -276,7 +277,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
pubsub := database.NewPubsubInMemory()
ps := pubsub.NewInMemory()
now := database.Now()
job := database.ProvisionerJob{
ID: uuid.New(),
@ -293,7 +294,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
// we need an HTTP server to get a websocket
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0)
uut.follow()
}))
defer srv.Close()
@ -342,7 +343,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
}
msg, err = json.Marshal(&n)
require.NoError(t, err)
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
require.NoError(t, err)
mt, msg, err = client.Read(ctx)
@ -360,7 +361,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
n.CreatedAfter = 0
msg, err = json.Marshal(&n)
require.NoError(t, err)
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
require.NoError(t, err)
// server should now close

View File

@ -20,6 +20,7 @@ import (
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/database/pubsub"
)
var PubsubEvent = "replica"
@ -36,7 +37,7 @@ type Options struct {
// New registers the replica with the database and periodically updates to ensure
// it's healthy. It contacts all other alive replicas to ensure they are reachable.
func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) {
func New(ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub, options *Options) (*Manager, error) {
if options == nil {
options = &Options{}
}
@ -77,7 +78,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data
if err != nil {
return nil, xerrors.Errorf("insert replica: %w", err)
}
err = pubsub.Publish(PubsubEvent, []byte(options.ID.String()))
err = ps.Publish(PubsubEvent, []byte(options.ID.String()))
if err != nil {
return nil, xerrors.Errorf("publish new replica: %w", err)
}
@ -86,7 +87,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data
id: options.ID,
options: options,
db: db,
pubsub: pubsub,
pubsub: ps,
self: replica,
logger: logger,
closed: make(chan struct{}),
@ -110,7 +111,7 @@ type Manager struct {
id uuid.UUID
options *Options
db database.Store
pubsub database.Pubsub
pubsub pubsub.Pubsub
logger slog.Logger
closeWait sync.WaitGroup

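Because New publishes the replica's ID on PubsubEvent when it registers, peers can watch that channel through the same pubsub.Pubsub handle. A hedged sketch of such a listener; the logging is illustrative, and the in-memory pubsub stands in for the PostgreSQL-backed one:

package main

import (
	"context"
	"log"

	"github.com/coder/coder/coderd/database/pubsub"
	"github.com/coder/coder/enterprise/replicasync"
)

func main() {
	ps := pubsub.NewInMemory()
	defer ps.Close()

	// Each replica publishes its UUID on replicasync.PubsubEvent when it
	// registers (see New above).
	cancel, err := ps.Subscribe(replicasync.PubsubEvent, func(_ context.Context, id []byte) {
		log.Printf("replica %s announced itself", id)
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cancel()
}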
View File

@ -18,6 +18,7 @@ import (
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbfake"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/enterprise/replicasync"
"github.com/coder/coder/testutil"
)
@ -212,7 +213,7 @@ func TestReplica(t *testing.T) {
// this many PostgreSQL connections takes some
// configuration tweaking.
db := dbfake.New()
pubsub := database.NewPubsubInMemory()
pubsub := pubsub.NewInMemory()
logger := slogtest.Make(t, nil)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

View File

@ -16,13 +16,13 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/pubsub"
agpl "github.com/coder/coder/tailnet"
)
// NewCoordinator creates a new high availability coordinator
// that uses PostgreSQL pubsub to exchange handshakes.
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, error) {
ctx, cancelFunc := context.WithCancel(context.Background())
nameCache, err := lru.New[uuid.UUID, string](512)
@ -33,7 +33,7 @@ func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinato
coord := &haCoordinator{
id: uuid.New(),
log: logger,
pubsub: pubsub,
pubsub: ps,
closeFunc: cancelFunc,
close: make(chan struct{}),
nodes: map[uuid.UUID]*agpl.Node{},
@ -53,7 +53,7 @@ type haCoordinator struct {
id uuid.UUID
log slog.Logger
mutex sync.RWMutex
pubsub database.Pubsub
pubsub pubsub.Pubsub
close chan struct{}
closeFunc context.CancelFunc

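The HA coordinator now takes its pubsub dependency directly as well. A minimal construction sketch; the logger setup is illustrative, and the in-memory pubsub stands in for the PostgreSQL-backed one used in production:

package main

import (
	"log"
	"os"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"

	"github.com/coder/coder/coderd/database/pubsub"
	"github.com/coder/coder/enterprise/tailnet"
)

func main() {
	logger := slog.Make(sloghuman.Sink(os.Stderr))
	ps := pubsub.NewInMemory()
	defer ps.Close()

	// NewCoordinator now accepts pubsub.Pubsub rather than database.Pubsub.
	coordinator, err := tailnet.NewCoordinator(logger, ps)
	if err != nil {
		log.Fatal(err)
	}
	defer coordinator.Close()
}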
View File

@ -10,8 +10,8 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/coderd/database/pubsub"
"github.com/coder/coder/enterprise/tailnet"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
@ -21,7 +21,7 @@ func TestCoordinatorSingle(t *testing.T) {
t.Parallel()
t.Run("ClientWithoutAgent", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
require.NoError(t, err)
defer coordinator.Close()
@ -49,7 +49,7 @@ func TestCoordinatorSingle(t *testing.T) {
t.Run("AgentWithoutClients", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
require.NoError(t, err)
defer coordinator.Close()
@ -77,7 +77,7 @@ func TestCoordinatorSingle(t *testing.T) {
t.Run("AgentWithClient", func(t *testing.T) {
t.Parallel()
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
require.NoError(t, err)
defer coordinator.Close()