mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
fix: use authenticated urls for pubsub (#14261)
This commit is contained in:
@ -10,7 +10,10 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
type awsIamRdsDriver struct {
|
||||
@ -18,7 +21,10 @@ type awsIamRdsDriver struct {
|
||||
cfg aws.Config
|
||||
}
|
||||
|
||||
var _ driver.Driver = &awsIamRdsDriver{}
|
||||
var (
|
||||
_ driver.Driver = &awsIamRdsDriver{}
|
||||
_ database.ConnectorCreator = &awsIamRdsDriver{}
|
||||
)
|
||||
|
||||
// Register initializes and registers our aws iam rds wrapped database driver.
|
||||
func Register(ctx context.Context, parentName string) (string, error) {
|
||||
@ -65,6 +71,16 @@ func (d *awsIamRdsDriver) Open(name string) (driver.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Connector returns a driver.Connector that fetches a new authentication token for each connection.
|
||||
func (d *awsIamRdsDriver) Connector(name string) (driver.Connector, error) {
|
||||
connector := &connector{
|
||||
url: name,
|
||||
cfg: d.cfg,
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
|
||||
nURL, err := url.Parse(dbURL)
|
||||
if err != nil {
|
||||
@ -82,3 +98,37 @@ func getAuthenticatedURL(cfg aws.Config, dbURL string) (string, error) {
|
||||
|
||||
return nURL.String(), nil
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
url string
|
||||
cfg aws.Config
|
||||
dialer pq.Dialer
|
||||
}
|
||||
|
||||
var _ database.DialerConnector = &connector{}
|
||||
|
||||
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
nURL, err := getAuthenticatedURL(c.cfg, c.url)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("assigning authentication token to url: %w", err)
|
||||
}
|
||||
|
||||
nc, err := pq.NewConnector(nURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("creating new connector: %w", err)
|
||||
}
|
||||
|
||||
if c.dialer != nil {
|
||||
nc.Dialer(c.dialer)
|
||||
}
|
||||
|
||||
return nc.Connect(ctx)
|
||||
}
|
||||
|
||||
func (*connector) Driver() driver.Driver {
|
||||
return &pq.Driver{}
|
||||
}
|
||||
|
||||
func (c *connector) Dialer(dialer pq.Dialer) {
|
||||
c.dialer = dialer
|
||||
}
|
||||
|
@ -7,10 +7,11 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/cli"
|
||||
awsrdsiam "github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@ -22,13 +23,15 @@ func TestDriver(t *testing.T) {
|
||||
// export DBAWSIAMRDS_TEST_URL="postgres://user@host:5432/dbname";
|
||||
url := os.Getenv("DBAWSIAMRDS_TEST_URL")
|
||||
if url == "" {
|
||||
t.Log("skipping test; no DBAWSIAMRDS_TEST_URL set")
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
sqlDriver, err := awsrdsiam.Register(ctx, "postgres")
|
||||
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := cli.ConnectToPostgres(ctx, slogtest.Make(t, nil), sqlDriver, url)
|
||||
@ -47,4 +50,23 @@ func TestDriver(t *testing.T) {
|
||||
var one int
|
||||
require.NoError(t, i.Scan(&one))
|
||||
require.Equal(t, 1, one)
|
||||
|
||||
ps, err := pubsub.New(ctx, logger, db, url)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotChan := make(chan struct{})
|
||||
subCancel, err := ps.Subscribe("test", func(_ context.Context, _ []byte) {
|
||||
close(gotChan)
|
||||
})
|
||||
defer subCancel()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ps.Publish("test", []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-gotChan:
|
||||
case <-ctx.Done():
|
||||
require.Fail(t, "timed out waiting for message")
|
||||
}
|
||||
}
|
||||
|
19
coderd/database/connector.go
Normal file
19
coderd/database/connector.go
Normal file
@ -0,0 +1,19 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// ConnectorCreator is a driver.Driver that can create a driver.Connector.
|
||||
type ConnectorCreator interface {
|
||||
driver.Driver
|
||||
Connector(name string) (driver.Connector, error)
|
||||
}
|
||||
|
||||
// DialerConnector is a driver.Connector that can set a pq.Dialer.
|
||||
type DialerConnector interface {
|
||||
driver.Connector
|
||||
Dialer(dialer pq.Dialer)
|
||||
}
|
79
coderd/database/dbtestutil/driver.go
Normal file
79
coderd/database/dbtestutil/driver.go
Normal file
@ -0,0 +1,79 @@
|
||||
package dbtestutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
var _ database.DialerConnector = &Connector{}
|
||||
|
||||
type Connector struct {
|
||||
name string
|
||||
driver *Driver
|
||||
dialer pq.Dialer
|
||||
}
|
||||
|
||||
func (c *Connector) Connect(_ context.Context) (driver.Conn, error) {
|
||||
if c.dialer != nil {
|
||||
conn, err := pq.DialOpen(c.dialer, c.name)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to dial open connection: %w", err)
|
||||
}
|
||||
|
||||
c.driver.Connections <- conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
conn, err := pq.Driver{}.Open(c.name)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to open connection: %w", err)
|
||||
}
|
||||
|
||||
c.driver.Connections <- conn
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Connector) Driver() driver.Driver {
|
||||
return c.driver
|
||||
}
|
||||
|
||||
func (c *Connector) Dialer(dialer pq.Dialer) {
|
||||
c.dialer = dialer
|
||||
}
|
||||
|
||||
type Driver struct {
|
||||
Connections chan driver.Conn
|
||||
}
|
||||
|
||||
func NewDriver() *Driver {
|
||||
return &Driver{
|
||||
Connections: make(chan driver.Conn, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Driver) Connector(name string) (driver.Connector, error) {
|
||||
return &Connector{
|
||||
name: name,
|
||||
driver: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
c, err := d.Connector(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.Connect(context.Background())
|
||||
}
|
||||
|
||||
func (d *Driver) Close() {
|
||||
close(d.Connections)
|
||||
}
|
@ -3,6 +3,7 @@ package pubsub
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@ -15,6 +16,8 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
@ -432,9 +435,35 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
// pq.defaultDialer uses a zero net.Dialer as well.
|
||||
d: net.Dialer{},
|
||||
}
|
||||
connector driver.Connector
|
||||
err error
|
||||
)
|
||||
|
||||
// Create a custom connector if the database driver supports it.
|
||||
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
|
||||
if ok {
|
||||
connector, err = connectorCreator.Connector(connectURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create custom connector: %w", err)
|
||||
}
|
||||
} else {
|
||||
// use the default pq connector otherwise
|
||||
connector, err = pq.NewConnector(connectURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create pq connector: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the dialer if the connector supports it.
|
||||
dc, ok := connector.(database.DialerConnector)
|
||||
if !ok {
|
||||
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
} else {
|
||||
dc.Dialer(dialer)
|
||||
}
|
||||
|
||||
p.pgListener = pqListenerShim{
|
||||
Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
|
||||
Listener: pq.NewConnectorListener(connector, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) {
|
||||
switch t {
|
||||
case pq.ListenerEventConnected:
|
||||
p.logger.Info(ctx, "pubsub connected to postgres")
|
||||
@ -583,8 +612,8 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) {
|
||||
}
|
||||
|
||||
// New creates a new Pubsub implementation using a PostgreSQL connection.
|
||||
func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) {
|
||||
p := newWithoutListener(logger, database)
|
||||
func New(startCtx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (*PGPubsub, error) {
|
||||
p := newWithoutListener(logger, db)
|
||||
if err := p.startListener(startCtx, connectURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -594,11 +623,11 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect
|
||||
}
|
||||
|
||||
// newWithoutListener creates a new PGPubsub without creating the pqListener.
|
||||
func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub {
|
||||
func newWithoutListener(logger slog.Logger, db *sql.DB) *PGPubsub {
|
||||
return &PGPubsub{
|
||||
logger: logger,
|
||||
listenDone: make(chan struct{}),
|
||||
db: database,
|
||||
db: db,
|
||||
queues: make(map[string]map[uuid.UUID]*msgQueue),
|
||||
latencyMeasurer: NewLatencyMeasurer(logger.Named("latency-measurer")),
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -51,7 +52,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
|
||||
event := "test"
|
||||
data := "testing"
|
||||
messageChannel := make(chan []byte)
|
||||
unsub0, err := uut.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
unsub0, err := uut.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -86,7 +87,7 @@ func TestPGPubsub_Metrics(t *testing.T) {
|
||||
for i := range colossalData {
|
||||
colossalData[i] = 'q'
|
||||
}
|
||||
unsub1, err := uut.Subscribe(event, func(ctx context.Context, message []byte) {
|
||||
unsub1, err := uut.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageChannel <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -119,3 +120,74 @@ func TestPGPubsub_Metrics(t *testing.T) {
|
||||
!testutil.PromCounterGathered(t, metrics, "coder_pubsub_latency_measure_errs_total")
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestPGPubsubDriver(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, closePg, err := dbtestutil.Open()
|
||||
require.NoError(t, err)
|
||||
defer closePg()
|
||||
|
||||
// use a separate subber and pubber so we can keep track of listener connections
|
||||
db, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
pubber, err := pubsub.New(ctx, logger, db, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer pubber.Close()
|
||||
|
||||
// use a connector that sends us the connections for the subber
|
||||
subDriver := dbtestutil.NewDriver()
|
||||
defer subDriver.Close()
|
||||
tconn, err := subDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
tcdb := sql.OpenDB(tconn)
|
||||
subber, err := pubsub.New(ctx, logger, tcdb, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer subber.Close()
|
||||
|
||||
// test that we can publish and subscribe
|
||||
gotChan := make(chan struct{}, 1)
|
||||
defer close(gotChan)
|
||||
subCancel, err := subber.Subscribe("test", func(_ context.Context, _ []byte) {
|
||||
gotChan <- struct{}{}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer subCancel()
|
||||
|
||||
// send a message
|
||||
err = pubber.Publish("test", []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait for the message
|
||||
_ = testutil.RequireRecvCtx(ctx, t, gotChan)
|
||||
|
||||
// read out first connection
|
||||
firstConn := testutil.RequireRecvCtx(ctx, t, subDriver.Connections)
|
||||
|
||||
// drop the underlying connection being used by the pubsub
|
||||
// the pq.Listener should reconnect and repopulate it's listeners
|
||||
// so old subscriptions should still work
|
||||
err = firstConn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait for the reconnect
|
||||
_ = testutil.RequireRecvCtx(ctx, t, subDriver.Connections)
|
||||
// we need to sleep because the raw connection notification
|
||||
// is sent before the pq.Listener can reestablish it's listeners
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// ensure our old subscription still fires
|
||||
err = pubber.Publish("test", []byte("hello-again"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait for the message on the old subscription
|
||||
_ = testutil.RequireRecvCtx(ctx, t, gotChan)
|
||||
}
|
||||
|
Reference in New Issue
Block a user