mirror of
https://github.com/coder/coder.git
synced 2025-07-23 21:32:07 +00:00
chore: separate pubsub into a new package (#8017)
* chore: rename store to dbmock for consistency * chore: remove redundant dbtype package This wasn't necessary and forked how we do DB types. * chore: separate pubsub into a new package This didn't need to be in database and was bloating it.
This commit is contained in:
@ -68,6 +68,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database/dbmetrics"
|
||||
"github.com/coder/coder/coderd/database/dbpurge"
|
||||
"github.com/coder/coder/coderd/database/migrations"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/devtunnel"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
@ -463,7 +464,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
Logger: logger.Named("coderd"),
|
||||
Database: dbfake.New(),
|
||||
DERPMap: derpMap,
|
||||
Pubsub: database.NewPubsubInMemory(),
|
||||
Pubsub: pubsub.NewInMemory(),
|
||||
CacheDir: cacheDir,
|
||||
GoogleTokenValidator: googleTokenValidator,
|
||||
GitAuthConfigs: gitAuthConfigs,
|
||||
@ -589,7 +590,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
if cfg.InMemoryDatabase {
|
||||
// This is only used for testing.
|
||||
options.Database = dbmetrics.New(dbfake.New(), options.PrometheusRegistry)
|
||||
options.Pubsub = database.NewPubsubInMemory()
|
||||
options.Pubsub = pubsub.NewInMemory()
|
||||
} else {
|
||||
sqlDB, err := connectToPostgres(ctx, logger, sqlDriver, cfg.PostgresURL.String())
|
||||
if err != nil {
|
||||
@ -600,7 +601,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
}()
|
||||
|
||||
options.Database = dbmetrics.New(database.New(sqlDB), options.PrometheusRegistry)
|
||||
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, cfg.PostgresURL.String())
|
||||
options.Pubsub, err = pubsub.New(ctx, sqlDB, cfg.PostgresURL.String())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
}
|
||||
|
@ -48,6 +48,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbmetrics"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/healthcheck"
|
||||
@ -95,7 +96,7 @@ type Options struct {
|
||||
AppHostnameRegex *regexp.Regexp
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
Pubsub pubsub.Pubsub
|
||||
|
||||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
|
@ -59,6 +59,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/healthcheck"
|
||||
@ -130,7 +131,7 @@ type Options struct {
|
||||
// It should only be used in cases where multiple Coder
|
||||
// test instances are running against the same database.
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
Pubsub pubsub.Pubsub
|
||||
|
||||
ConfigSSH codersdk.SSHConfigResponse
|
||||
|
||||
|
@ -11,13 +11,14 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
func NewDB(t testing.TB) (database.Store, database.Pubsub) {
|
||||
func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) {
|
||||
t.Helper()
|
||||
|
||||
db := dbfake.New()
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
if os.Getenv("DB") != "" {
|
||||
connectionURL := os.Getenv("CODER_PG_CONNECTION_URL")
|
||||
if connectionURL == "" {
|
||||
@ -36,12 +37,12 @@ func NewDB(t testing.TB) (database.Store, database.Pubsub) {
|
||||
})
|
||||
db = database.New(sqlDB)
|
||||
|
||||
pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL)
|
||||
ps, err = pubsub.New(context.Background(), sqlDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = pubsub.Close()
|
||||
_ = ps.Close()
|
||||
})
|
||||
}
|
||||
|
||||
return db, pubsub
|
||||
return db, ps
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package database
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -48,7 +48,7 @@ type msgOrErr struct {
|
||||
type msgQueue struct {
|
||||
ctx context.Context
|
||||
cond *sync.Cond
|
||||
q [PubsubBufferSize]msgOrErr
|
||||
q [BufferSize]msgOrErr
|
||||
front int
|
||||
size int
|
||||
closed bool
|
||||
@ -82,7 +82,7 @@ func (q *msgQueue) run() {
|
||||
return
|
||||
}
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % PubsubBufferSize
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
q.cond.L.Unlock()
|
||||
|
||||
@ -111,20 +111,20 @@ func (q *msgQueue) enqueue(msg []byte) {
|
||||
q.cond.L.Lock()
|
||||
defer q.cond.L.Unlock()
|
||||
|
||||
if q.size == PubsubBufferSize {
|
||||
if q.size == BufferSize {
|
||||
// queue is full, so we're going to drop the msg we got called with.
|
||||
// We also need to record that messages are being dropped, which we
|
||||
// do at the last message in the queue. This potentially makes us
|
||||
// lose 2 messages instead of one, but it's more important at this
|
||||
// point to warn the subscriber that they're losing messages so they
|
||||
// can do something about it.
|
||||
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
|
||||
back := (q.front + BufferSize - 1) % BufferSize
|
||||
q.q[back].msg = nil
|
||||
q.q[back].err = ErrDroppedMessages
|
||||
return
|
||||
}
|
||||
// queue is not full, insert the message
|
||||
next := (q.front + q.size) % PubsubBufferSize
|
||||
next := (q.front + q.size) % BufferSize
|
||||
q.q[next].msg = msg
|
||||
q.q[next].err = nil
|
||||
q.size++
|
||||
@ -143,17 +143,17 @@ func (q *msgQueue) dropped() {
|
||||
q.cond.L.Lock()
|
||||
defer q.cond.L.Unlock()
|
||||
|
||||
if q.size == PubsubBufferSize {
|
||||
if q.size == BufferSize {
|
||||
// queue is full, but we need to record that messages are being dropped,
|
||||
// which we do at the last message in the queue. This potentially drops
|
||||
// another message, but it's more important for the subscriber to know.
|
||||
back := (q.front + PubsubBufferSize - 1) % PubsubBufferSize
|
||||
back := (q.front + BufferSize - 1) % BufferSize
|
||||
q.q[back].msg = nil
|
||||
q.q[back].err = ErrDroppedMessages
|
||||
return
|
||||
}
|
||||
// queue is not full, insert the error
|
||||
next := (q.front + q.size) % PubsubBufferSize
|
||||
next := (q.front + q.size) % BufferSize
|
||||
q.q[next].msg = nil
|
||||
q.q[next].err = ErrDroppedMessages
|
||||
q.size++
|
||||
@ -171,9 +171,9 @@ type pgPubsub struct {
|
||||
queues map[string]map[uuid.UUID]*msgQueue
|
||||
}
|
||||
|
||||
// PubsubBufferSize is the maximum number of unhandled messages we will buffer
|
||||
// BufferSize is the maximum number of unhandled messages we will buffer
|
||||
// for a subscriber before dropping messages.
|
||||
const PubsubBufferSize = 2048
|
||||
const BufferSize = 2048
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
@ -295,8 +295,8 @@ func (p *pgPubsub) recordReconnect() {
|
||||
}
|
||||
}
|
||||
|
||||
// NewPubsub creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
||||
// New creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func New(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
|
||||
// Creates a new listener using pq.
|
||||
errCh := make(chan error)
|
||||
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
|
@ -1,4 +1,4 @@
|
||||
package database
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -26,7 +26,7 @@ func Test_msgQueue_ListenerWithError(t *testing.T) {
|
||||
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
|
||||
// when we wrap around the end of the circular buffer. This tests that we correctly handle
|
||||
// the wrapping and aren't dequeueing misaligned data.
|
||||
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
|
||||
cycles := (BufferSize / 5) * 2 // almost twice around the ring
|
||||
for j := 0; j < cycles; j++ {
|
||||
for i := 0; i < 4; i++ {
|
||||
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
|
||||
@ -75,7 +75,7 @@ func Test_msgQueue_Listener(t *testing.T) {
|
||||
// PubsubBufferSize is 2048, which is a power of 2, so a pattern of 5 will not be aligned
|
||||
// when we wrap around the end of the circular buffer. This tests that we correctly handle
|
||||
// the wrapping and aren't dequeueing misaligned data.
|
||||
cycles := (PubsubBufferSize / 5) * 2 // almost twice around the ring
|
||||
cycles := (BufferSize / 5) * 2 // almost twice around the ring
|
||||
for j := 0; j < cycles; j++ {
|
||||
for i := 0; i < 4; i++ {
|
||||
uut.enqueue([]byte(fmt.Sprintf("%d%d", j, i)))
|
||||
@ -119,7 +119,7 @@ func Test_msgQueue_Full(t *testing.T) {
|
||||
// we send 2 more than the capacity. One extra because the call to the ListenerFunc blocks
|
||||
// but only after we've dequeued a message, and then another extra because we want to exceed
|
||||
// the capacity, not just reach it.
|
||||
for i := 0; i < PubsubBufferSize+2; i++ {
|
||||
for i := 0; i < BufferSize+2; i++ {
|
||||
uut.enqueue([]byte(fmt.Sprintf("%d", i)))
|
||||
// ensure the first dequeue has happened before proceeding, so that this function isn't racing
|
||||
// against the goroutine that dequeues items.
|
||||
@ -136,5 +136,5 @@ func Test_msgQueue_Full(t *testing.T) {
|
||||
// Ok, so we sent 2 more than capacity, but we only read the capacity, that's because the last
|
||||
// message we send doesn't get queued, AND, it bumps a message out of the queue to make room
|
||||
// for the error, so we read 2 less than we sent.
|
||||
require.Equal(t, PubsubBufferSize, n)
|
||||
require.Equal(t, BufferSize, n)
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package database
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -87,7 +87,7 @@ func (*memoryPubsub) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPubsubInMemory() Pubsub {
|
||||
func NewInMemory() Pubsub {
|
||||
return &memoryPubsub{
|
||||
listeners: make(map[string]map[uuid.UUID]genericListener),
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package database_test
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
func TestPubsubMemory(t *testing.T) {
|
||||
@ -16,7 +16,7 @@ func TestPubsubMemory(t *testing.T) {
|
||||
t.Run("Legacy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pubsub := pubsub.NewInMemory()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
messageChannel := make(chan []byte)
|
||||
@ -36,7 +36,7 @@ func TestPubsubMemory(t *testing.T) {
|
||||
t.Run("WithErr", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pubsub := pubsub.NewInMemory()
|
||||
event := "test"
|
||||
data := "testing"
|
||||
messageChannel := make(chan []byte)
|
@ -1,6 +1,6 @@
|
||||
//go:build linux
|
||||
|
||||
package database_test
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -15,8 +15,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
@ -39,7 +39,7 @@ func TestPubsub(t *testing.T) {
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
pubsub, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
event := "test"
|
||||
@ -67,7 +67,7 @@ func TestPubsub(t *testing.T) {
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
pubsub, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
cancelFunc()
|
||||
@ -82,7 +82,7 @@ func TestPubsub(t *testing.T) {
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
pubsub, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
|
||||
@ -114,7 +114,7 @@ func TestPubsub(t *testing.T) {
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
pubsub, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
|
||||
@ -171,12 +171,12 @@ func TestPubsub_ordering(t *testing.T) {
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
ps, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
defer ps.Close()
|
||||
event := "test"
|
||||
messageChannel := make(chan []byte, 100)
|
||||
cancelSub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
cancelSub, err := ps.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
// sleep a random amount of time to simulate handlers taking different amount of time
|
||||
// to process, depending on the message
|
||||
// nolint: gosec
|
||||
@ -187,7 +187,7 @@ func TestPubsub_ordering(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer cancelSub()
|
||||
for i := 0; i < 100; i++ {
|
||||
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
|
||||
err = ps.Publish(event, []byte(fmt.Sprintf("%d", i)))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
@ -219,14 +219,14 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
|
||||
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
|
||||
defer cancelFunc()
|
||||
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
|
||||
ps, err := pubsub.New(ctx, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubsub.Close()
|
||||
defer ps.Close()
|
||||
event := "test"
|
||||
|
||||
// buffer responses so that when the test completes, goroutines don't get blocked & leak
|
||||
errors := make(chan error, database.PubsubBufferSize)
|
||||
messages := make(chan string, database.PubsubBufferSize)
|
||||
errors := make(chan error, pubsub.BufferSize)
|
||||
messages := make(chan string, pubsub.BufferSize)
|
||||
readOne := func() (m string, e error) {
|
||||
t.Helper()
|
||||
select {
|
||||
@ -244,7 +244,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
return m, e
|
||||
}
|
||||
|
||||
cancelSub, err := pubsub.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
|
||||
cancelSub, err := ps.SubscribeWithErr(event, func(ctx context.Context, msg []byte, err error) {
|
||||
messages <- string(msg)
|
||||
errors <- err
|
||||
})
|
||||
@ -252,7 +252,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
defer cancelSub()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", i)))
|
||||
err = ps.Publish(event, []byte(fmt.Sprintf("%d", i)))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
// make sure we're getting at least one message.
|
||||
@ -270,7 +270,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
default:
|
||||
// ok
|
||||
}
|
||||
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
err = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
j++
|
||||
if err != nil {
|
||||
break
|
||||
@ -292,7 +292,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
default:
|
||||
// ok
|
||||
}
|
||||
err = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
err = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
@ -303,7 +303,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
k := j
|
||||
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than DB
|
||||
// reconnect
|
||||
require.Less(t, k, database.PubsubBufferSize, "exceeded buffer")
|
||||
require.Less(t, k, pubsub.BufferSize, "exceeded buffer")
|
||||
|
||||
// We don't know how quickly the pubsub will reconnect, so continue to send messages with increasing numbers. As
|
||||
// soon as we see k or higher we know we're getting messages after the restart.
|
||||
@ -315,7 +315,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
default:
|
||||
// ok
|
||||
}
|
||||
_ = pubsub.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
_ = ps.Publish(event, []byte(fmt.Sprintf("%d", j)))
|
||||
j++
|
||||
time.Sleep(testutil.IntervalFast)
|
||||
}
|
||||
@ -324,7 +324,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
gotDroppedErr := false
|
||||
for {
|
||||
m, err := readOne()
|
||||
if xerrors.Is(err, database.ErrDroppedMessages) {
|
||||
if xerrors.Is(err, pubsub.ErrDroppedMessages) {
|
||||
gotDroppedErr = true
|
||||
continue
|
||||
}
|
||||
@ -334,7 +334,7 @@ func TestPubsub_Disconnect(t *testing.T) {
|
||||
if l >= k {
|
||||
// exceeding the buffer invalidates the test because this causes us to drop messages for reasons other than
|
||||
// DB reconnect
|
||||
require.Less(t, l, database.PubsubBufferSize, "exceeded buffer")
|
||||
require.Less(t, l, pubsub.BufferSize, "exceeded buffer")
|
||||
break
|
||||
}
|
||||
}
|
@ -31,6 +31,7 @@ import (
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/schedule"
|
||||
@ -56,7 +57,7 @@ type Server struct {
|
||||
GitAuthConfigs []*gitauth.Config
|
||||
Tags json.RawMessage
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
Pubsub pubsub.Pubsub
|
||||
Telemetry telemetry.Reporter
|
||||
Tracer trace.Tracer
|
||||
QuotaCommitter *atomic.Pointer[proto.QuotaCommitter]
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbgen"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/gitauth"
|
||||
"github.com/coder/coder/coderd/provisionerdserver"
|
||||
"github.com/coder/coder/coderd/schedule"
|
||||
@ -51,14 +52,14 @@ func TestAcquireJob(t *testing.T) {
|
||||
t.Run("Debounce", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := dbfake.New()
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
srv := &provisionerdserver.Server{
|
||||
ID: uuid.New(),
|
||||
Logger: slogtest.Make(t, nil),
|
||||
AccessURL: &url.URL{},
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
Pubsub: ps,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
AcquireJobDebounce: time.Hour,
|
||||
Auditor: mockAuditor(),
|
||||
@ -1256,7 +1257,7 @@ func TestInsertWorkspaceResource(t *testing.T) {
|
||||
func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
|
||||
t.Helper()
|
||||
db := dbfake.New()
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
|
||||
return &provisionerdserver.Server{
|
||||
ID: uuid.New(),
|
||||
@ -1265,7 +1266,7 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
|
||||
AccessURL: &url.URL{},
|
||||
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
Pubsub: ps,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
Auditor: mockAuditor(),
|
||||
TemplateScheduleStore: testTemplateScheduleStore(),
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
@ -268,7 +269,7 @@ type logFollower struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
db database.Store
|
||||
pubsub database.Pubsub
|
||||
pubsub pubsub.Pubsub
|
||||
r *http.Request
|
||||
rw http.ResponseWriter
|
||||
conn *websocket.Conn
|
||||
@ -281,14 +282,14 @@ type logFollower struct {
|
||||
}
|
||||
|
||||
func newLogFollower(
|
||||
ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub,
|
||||
ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub,
|
||||
rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
|
||||
) *logFollower {
|
||||
return &logFollower{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
db: db,
|
||||
pubsub: pubsub,
|
||||
pubsub: ps,
|
||||
r: r,
|
||||
rw: rw,
|
||||
jobID: job.ID,
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbmock"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
@ -138,7 +139,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) {
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
@ -157,7 +158,7 @@ func Test_logFollower_completeBeforeFollow(t *testing.T) {
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 10)
|
||||
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 10)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
@ -200,7 +201,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
@ -217,7 +218,7 @@ func Test_logFollower_completeBeforeSubscribe(t *testing.T) {
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
|
||||
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
@ -276,7 +277,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
ps := pubsub.NewInMemory()
|
||||
now := database.Now()
|
||||
job := database.ProvisionerJob{
|
||||
ID: uuid.New(),
|
||||
@ -293,7 +294,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
|
||||
|
||||
// we need an HTTP server to get a websocket
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
uut := newLogFollower(ctx, logger, mDB, pubsub, rw, r, job, 0)
|
||||
uut := newLogFollower(ctx, logger, mDB, ps, rw, r, job, 0)
|
||||
uut.follow()
|
||||
}))
|
||||
defer srv.Close()
|
||||
@ -342,7 +343,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
|
||||
}
|
||||
msg, err = json.Marshal(&n)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
mt, msg, err = client.Read(ctx)
|
||||
@ -360,7 +361,7 @@ func Test_logFollower_EndOfLogs(t *testing.T) {
|
||||
n.CreatedAfter = 0
|
||||
msg, err = json.Marshal(&n)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
err = ps.Publish(provisionersdk.ProvisionerJobLogsNotifyChannel(job.ID), msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// server should now close
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
var PubsubEvent = "replica"
|
||||
@ -36,7 +37,7 @@ type Options struct {
|
||||
|
||||
// New registers the replica with the database and periodically updates to ensure
|
||||
// it's healthy. It contacts all other alive replicas to ensure they are reachable.
|
||||
func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub database.Pubsub, options *Options) (*Manager, error) {
|
||||
func New(ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub, options *Options) (*Manager, error) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
@ -77,7 +78,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert replica: %w", err)
|
||||
}
|
||||
err = pubsub.Publish(PubsubEvent, []byte(options.ID.String()))
|
||||
err = ps.Publish(PubsubEvent, []byte(options.ID.String()))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("publish new replica: %w", err)
|
||||
}
|
||||
@ -86,7 +87,7 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, pubsub data
|
||||
id: options.ID,
|
||||
options: options,
|
||||
db: db,
|
||||
pubsub: pubsub,
|
||||
pubsub: ps,
|
||||
self: replica,
|
||||
logger: logger,
|
||||
closed: make(chan struct{}),
|
||||
@ -110,7 +111,7 @@ type Manager struct {
|
||||
id uuid.UUID
|
||||
options *Options
|
||||
db database.Store
|
||||
pubsub database.Pubsub
|
||||
pubsub pubsub.Pubsub
|
||||
logger slog.Logger
|
||||
|
||||
closeWait sync.WaitGroup
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/enterprise/replicasync"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
@ -212,7 +213,7 @@ func TestReplica(t *testing.T) {
|
||||
// this many PostgreSQL connections takes some
|
||||
// configuration tweaking.
|
||||
db := dbfake.New()
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pubsub := pubsub.NewInMemory()
|
||||
logger := slogtest.Make(t, nil)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
@ -16,13 +16,13 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
agpl "github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
// NewCoordinator creates a new high availability coordinator
|
||||
// that uses PostgreSQL pubsub to exchange handshakes.
|
||||
func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinator, error) {
|
||||
func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, error) {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
nameCache, err := lru.New[uuid.UUID, string](512)
|
||||
@ -33,7 +33,7 @@ func NewCoordinator(logger slog.Logger, pubsub database.Pubsub) (agpl.Coordinato
|
||||
coord := &haCoordinator{
|
||||
id: uuid.New(),
|
||||
log: logger,
|
||||
pubsub: pubsub,
|
||||
pubsub: ps,
|
||||
closeFunc: cancelFunc,
|
||||
close: make(chan struct{}),
|
||||
nodes: map[uuid.UUID]*agpl.Node{},
|
||||
@ -53,7 +53,7 @@ type haCoordinator struct {
|
||||
id uuid.UUID
|
||||
log slog.Logger
|
||||
mutex sync.RWMutex
|
||||
pubsub database.Pubsub
|
||||
pubsub pubsub.Pubsub
|
||||
close chan struct{}
|
||||
closeFunc context.CancelFunc
|
||||
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/database/pubsub"
|
||||
"github.com/coder/coder/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/testutil"
|
||||
@ -21,7 +21,7 @@ func TestCoordinatorSingle(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("ClientWithoutAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
@ -49,7 +49,7 @@ func TestCoordinatorSingle(t *testing.T) {
|
||||
|
||||
t.Run("AgentWithoutClients", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
@ -77,7 +77,7 @@ func TestCoordinatorSingle(t *testing.T) {
|
||||
t.Run("AgentWithClient", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), database.NewPubsubInMemory())
|
||||
coordinator, err := tailnet.NewCoordinator(slogtest.Make(t, nil), pubsub.NewInMemory())
|
||||
require.NoError(t, err)
|
||||
defer coordinator.Close()
|
||||
|
||||
|
Reference in New Issue
Block a user