Files
coder/coderd/database/dbtestutil/postgres.go
Spike Curtis f0251dfc91 chore: retry postgres connection on reset by peer in tests (#18632)
Fixes https://github.com/coder/internal/issues/695

Retries initial connection to postgres in testing up to 3 seconds if we
see "reset by peer", which probably means that some other test proc just
started the container.

---------

Co-authored-by: Hugo Dutka <hugo@coder.com>
2025-06-27 13:03:32 +00:00

553 lines
17 KiB
Go

package dbtestutil
import (
	"context"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/gofrs/flock"
	"github.com/ory/dockertest/v3"
	"github.com/ory/dockertest/v3/docker"
	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/coderd/database/migrations"
	"github.com/coder/coder/v2/cryptorand"
	"github.com/coder/retry"
)
// postgresImage is the image repository passed to dockertest when
// openContainer has to start a new postgres container.
const postgresImage = "us-docker.pkg.dev/coder-v2-images-public/public/postgres"
// ConnectionParams holds the pieces needed to build a postgres connection URL.
type ConnectionParams struct {
	Username string
	Password string
	Host     string
	Port     string
	DBName   string
}

// DSN renders the parameters as a postgres:// connection URL with TLS disabled
// (sslmode=disable).
//
// The URL is assembled with net/url so that credentials and the database name
// are percent-encoded when they contain characters that are special in a URL
// (for example '@', '/', ':' or '?'). For values without such characters the
// result is identical to plain string formatting.
func (p ConnectionParams) DSN() string {
	u := url.URL{
		Scheme: "postgres",
		User:   url.UserPassword(p.Username, p.Password),
		// Host and port are joined verbatim, exactly as callers provide them.
		Host:     p.Host + ":" + p.Port,
		Path:     "/" + p.DBName,
		RawQuery: "sslmode=disable",
	}
	return u.String()
}
// These variables are global because all tests share them.
var (
// connectionParamsInitOnce ensures Open runs initDefaultConnection only once.
connectionParamsInitOnce sync.Once
// defaultConnectionParams is set by initDefaultConnection on success and is
// read by Open to reach the shared postgres server.
defaultConnectionParams ConnectionParams
// errDefaultConnectionParamsInit stores the result of initDefaultConnection
// so that every call to Open can report it.
errDefaultConnectionParamsInit error
// retryableErrSubstrings lists error substrings for which
// initDefaultConnection retries the initial connection attempt.
retryableErrSubstrings = []string{
"connection reset by peer",
}
// noPostgresRunningErrSubstrings lists error substrings that
// initDefaultConnection treats as "no server on the default port", in which
// case it starts a container.
noPostgresRunningErrSubstrings = []string{
"connection refused", // nothing is listening on the port
"No connection could be made", // Windows variant of the above
}
)
// initDefaultConnection initializes the default postgres connection parameters.
// It first checks if the database is running at localhost:5432. If it is, it will
// use that database. If it's not, it will start a new container and use that.
//
// On success the chosen parameters are stored in defaultConnectionParams.
// An error is returned when the initial connection fails for a reason other
// than "nothing is listening", when the container cannot be started, or when
// the freshly started container does not accept connections in time.
func initDefaultConnection(t TBSubset) error {
	params := ConnectionParams{
		Username: "postgres",
		Password: "postgres",
		Host:     "127.0.0.1",
		Port:     "5432",
		DBName:   "postgres",
	}
	dsn := params.DSN()

	// Helper closure to try opening and pinging the default Postgres instance.
	// Used within a single retry loop that handles both retryable and permanent errors.
	attemptConn := func() error {
		db, err := sql.Open("postgres", dsn)
		if err == nil {
			err = db.Ping()
			if closeErr := db.Close(); closeErr != nil {
				return xerrors.Errorf("close db: %w", closeErr)
			}
		}
		return err
	}

	var dbErr error
	// Retry up to 3 seconds for temporary errors.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	for r := retry.New(10*time.Millisecond, 500*time.Millisecond); r.Wait(ctx); {
		dbErr = attemptConn()
		if dbErr == nil {
			break
		}
		errString := dbErr.Error()
		if !containsAnySubstring(errString, retryableErrSubstrings) {
			// Not a transient error; stop retrying and classify it below.
			break
		}
		t.Logf("failed to connect to postgres, retrying: %s", errString)
	}

	// After the loop dbErr is the last connection error (if any).
	if dbErr != nil && containsAnySubstring(dbErr.Error(), noPostgresRunningErrSubstrings) {
		// If there's no database running on the default port, we'll start a
		// postgres container. We won't be cleaning it up so it can be reused
		// by subsequent tests. It'll keep on running until the user terminates
		// it manually.
		container, _, err := openContainer(t, DBContainerOptions{
			Name: "coder-test-postgres",
			Port: 5432,
		})
		if err != nil {
			return xerrors.Errorf("open container: %w", err)
		}
		params.Host = container.Host
		params.Port = container.Port
		dsn = params.DSN()

		// Retry connecting for at most 10 seconds.
		// The fact that openContainer succeeded does not
		// mean that port forwarding is ready.
		//
		// The deadline is enforced through the context: the second argument
		// to retry.New only caps the delay between attempts, it does not
		// limit the total time spent retrying.
		containerCtx, containerCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer containerCancel()
		var connErr error
		connected := false
		for r := retry.New(100*time.Millisecond, time.Second); r.Wait(containerCtx); {
			var db *sql.DB
			db, connErr = sql.Open("postgres", dsn)
			if connErr == nil {
				connErr = db.Ping()
				if closeErr := db.Close(); closeErr != nil {
					return xerrors.Errorf("close db, container: %w", closeErr)
				}
			}
			if connErr == nil {
				connected = true
				break
			}
			t.Logf("failed to connect to postgres after starting container, may retry: %s", connErr.Error())
		}
		if !connected {
			if connErr == nil {
				// The deadline expired before a single attempt was made.
				connErr = containerCtx.Err()
			}
			return xerrors.Errorf("connect to postgres after starting container: %w", connErr)
		}
	} else if dbErr != nil {
		return xerrors.Errorf("open postgres connection: %w", dbErr)
	}

	defaultConnectionParams = params
	return nil
}
// OpenOptions collects the optional settings accepted by Open.
type OpenOptions struct {
	// DBFrom, when non-nil, names the template database to create from.
	DBFrom *string
}

// OpenOption mutates an OpenOptions value; pass any number of them to Open.
type OpenOption func(*OpenOptions)

// WithDBFrom returns an option that selects the template database used when
// creating a new database. It takes precedence over the DB_FROM environment
// variable.
func WithDBFrom(dbFrom string) OpenOption {
	apply := func(target *OpenOptions) {
		target.DBFrom = &dbFrom
	}
	return apply
}
// TBSubset is a subset of the testing.TB interface.
// It allows to use dbtestutil.Open outside of tests.
type TBSubset interface {
// Cleanup registers a function to run later; Open uses it to drop the
// database it created.
Cleanup(func())
// Helper is called at the start of Open and createDatabaseFromTemplate.
Helper()
// Logf reports diagnostics such as retried connections and cleanup failures.
Logf(format string, args ...any)
}
// Open creates a new, uniquely named PostgreSQL database and returns its DSN.
// If there's a database server running at localhost:5432, it will use that.
// Otherwise, it will start a new postgres container.
//
// The database is created from a template (see createDatabaseFromTemplate) and
// a cleanup function is registered on t that drops it again.
func Open(t TBSubset, opts ...OpenOption) (string, error) {
	t.Helper()

	// The server-level parameters are resolved once and shared by all callers.
	connectionParamsInitOnce.Do(func() {
		errDefaultConnectionParamsInit = initDefaultConnection(t)
	})
	if initErr := errDefaultConnectionParamsInit; initErr != nil {
		return "", xerrors.Errorf("init default connection params: %w", initErr)
	}

	var options OpenOptions
	for i := range opts {
		opts[i](&options)
	}

	// Connection parameters for the database we are about to create; the
	// database name is filled in below.
	target := ConnectionParams{
		Username: defaultConnectionParams.Username,
		Password: defaultConnectionParams.Password,
		Host:     defaultConnectionParams.Host,
		Port:     defaultConnectionParams.Port,
	}

	// Use a time-based prefix to make it easier to find the database
	// when debugging.
	prefix := time.Now().Format("test_2006_01_02_15_04_05")
	suffix, err := cryptorand.StringCharset(cryptorand.Lower, 10)
	if err != nil {
		return "", xerrors.Errorf("generate db suffix: %w", err)
	}
	dbName := prefix + "_" + suffix
	target.DBName = dbName

	// An explicit option wins over the environment. If the name ends up empty,
	// createDatabaseFromTemplate will create a new template db.
	var templateDBName string
	if options.DBFrom != nil {
		templateDBName = *options.DBFrom
	} else {
		templateDBName = os.Getenv("DB_FROM")
	}

	err = createDatabaseFromTemplate(t, defaultConnectionParams, dbName, templateDBName)
	if err != nil {
		return "", xerrors.Errorf("create database: %w", err)
	}

	// dropDatabase removes the database again; failures are only logged.
	dropDatabase := func() {
		conn, openErr := sql.Open("postgres", defaultConnectionParams.DSN())
		if openErr != nil {
			t.Logf("cleanup database %q: failed to connect to postgres: %s\n", dbName, openErr.Error())
			return
		}
		defer func() {
			closeErr := conn.Close()
			if closeErr != nil {
				t.Logf("cleanup database %q: failed to close connection: %s\n", dbName, closeErr.Error())
			}
		}()
		if _, dropErr := conn.Exec("DROP DATABASE " + dbName + ";"); dropErr != nil {
			t.Logf("failed to clean up database %q: %s\n", dbName, dropErr.Error())
		}
	}
	t.Cleanup(dropDatabase)

	return target.DSN(), nil
}
// createDatabaseFromTemplate creates a new database from a template database.
// If templateDBName is empty, it will create a new template database based on
// the current migrations, and name it "tpl_<migrations_hash>". Or if it's
// already been created, it will use that.
//
// An explicitly named template that does not exist is reported as an error;
// only the default template is ever created on demand.
func createDatabaseFromTemplate(t TBSubset, connParams ConnectionParams, newDBName string, templateDBName string) error {
	t.Helper()

	dbURL := connParams.DSN()
	db, err := sql.Open("postgres", dbURL)
	if err != nil {
		return xerrors.Errorf("connect to postgres: %w", err)
	}
	defer func() {
		if err := db.Close(); err != nil {
			t.Logf("create database from template: failed to close connection: %s\n", err.Error())
		}
	}()

	emptyTemplateDBName := templateDBName == ""
	if emptyTemplateDBName {
		templateDBName = fmt.Sprintf("tpl_%s", migrations.GetMigrationsHash()[:32])
	}
	_, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName)
	if err == nil {
		// Template database already exists and we successfully created the new database.
		return nil
	}
	tplDbDoesNotExistOccurred := strings.Contains(err.Error(), "template database") && strings.Contains(err.Error(), "does not exist")
	if !tplDbDoesNotExistOccurred || !emptyTemplateDBName {
		// Either some unrelated error occurred, or the caller named a
		// template database that doesn't exist. Neither can be recovered from.
		return xerrors.Errorf("create db with template: %w", err)
	}
	// From here on the default template is missing, so we need to create it.

	// We will use a tx to obtain a lock, so another test or process doesn't race with us.
	tx, err := db.BeginTx(context.Background(), nil)
	if err != nil {
		return xerrors.Errorf("begin tx: %w", err)
	}
	defer func() {
		err := tx.Rollback()
		if err != nil && !errors.Is(err, sql.ErrTxDone) {
			t.Logf("create database from template: failed to rollback tx: %s\n", err.Error())
		}
	}()
	// 2137 is an arbitrary number. We just need a lock that is unique to creating
	// the template database.
	_, err = tx.Exec("SELECT pg_advisory_xact_lock(2137)")
	if err != nil {
		return xerrors.Errorf("acquire lock: %w", err)
	}

	// Someone else might have created the template db while we were waiting.
	tplDbExistsRes, err := tx.Query("SELECT 1 FROM pg_database WHERE datname = $1", templateDBName)
	if err != nil {
		return xerrors.Errorf("check if db exists: %w", err)
	}
	tplDbAlreadyExists := tplDbExistsRes.Next()
	if err := tplDbExistsRes.Close(); err != nil {
		return xerrors.Errorf("close tpl db exists res: %w", err)
	}
	// Next also returns false when iteration fails; don't mistake a failed
	// query for "the template does not exist".
	if err := tplDbExistsRes.Err(); err != nil {
		return xerrors.Errorf("check if db exists: %w", err)
	}

	if !tplDbAlreadyExists {
		// We will use a temporary template database to avoid race conditions. We will
		// rename it to the real template database name after we're sure it was fully
		// initialized.
		// It's dropped here to ensure that if a previous run of this function failed
		// midway, we don't encounter issues with the temporary database still existing.
		tmpTemplateDBName := "tmp_" + templateDBName
		// We're using db instead of tx here because you can't run `DROP DATABASE` inside
		// a transaction.
		if _, err := db.Exec("DROP DATABASE IF EXISTS " + tmpTemplateDBName); err != nil {
			return xerrors.Errorf("drop tmp template db: %w", err)
		}
		if _, err := db.Exec("CREATE DATABASE " + tmpTemplateDBName); err != nil {
			return xerrors.Errorf("create tmp template db: %w", err)
		}
		tplDbURL := ConnectionParams{
			Username: connParams.Username,
			Password: connParams.Password,
			Host:     connParams.Host,
			Port:     connParams.Port,
			DBName:   tmpTemplateDBName,
		}.DSN()
		tplDb, err := sql.Open("postgres", tplDbURL)
		if err != nil {
			return xerrors.Errorf("connect to template db: %w", err)
		}
		// Safety net for the early returns below; the explicit Close further
		// down handles the success path.
		defer func() {
			if err := tplDb.Close(); err != nil {
				t.Logf("create database from template: failed to close template db: %s\n", err.Error())
			}
		}()
		if err := migrations.Up(tplDb); err != nil {
			return xerrors.Errorf("migrate template db: %w", err)
		}
		// Close our connection before renaming the database.
		if err := tplDb.Close(); err != nil {
			return xerrors.Errorf("close template db: %w", err)
		}
		if _, err := db.Exec("ALTER DATABASE " + tmpTemplateDBName + " RENAME TO " + templateDBName); err != nil {
			return xerrors.Errorf("rename tmp template db: %w", err)
		}
	}

	// Try to create the database again now that a template exists.
	if _, err = db.Exec("CREATE DATABASE " + newDBName + " WITH TEMPLATE " + templateDBName); err != nil {
		return xerrors.Errorf("create db with template after migrations: %w", err)
	}
	// Committing ends the transaction and with it the transaction-scoped advisory lock.
	if err = tx.Commit(); err != nil {
		return xerrors.Errorf("commit tx: %w", err)
	}
	return nil
}
// DBContainerOptions configures the postgres container started by
// openContainer and OpenContainerized.
type DBContainerOptions struct {
// Port is the host port bound to the container's 5432/tcp port.
Port int
// Name, if non-empty, is the container name. openContainer reuses an
// existing container with that name instead of starting a new one.
Name string
}
// container describes the postgres container returned by openContainer.
type container struct {
// Resource is the dockertest handle of the container.
Resource *dockertest.Resource
// Pool is the dockertest pool the container was found or started in.
Pool *dockertest.Pool
// Host and Port are split from the host address that is mapped to the
// container's 5432/tcp port.
Host string
Port string
}
// openContainer creates a new PostgreSQL server using a Docker container. If port is nonzero, forward host traffic
// to that port to the database. If port is zero, allocate a free port from the OS.
// If name is set, we'll ensure that only one container is started with that name. If it's already running, we'll use that.
// Otherwise, we'll start a new container.
//
// It returns the container details and a cleanup function that purges the
// container and removes the temporary directory created for it (if any). An
// error is returned if the container cannot be started or if pg_isready does
// not report a ready server before the readiness deadline.
func openContainer(t TBSubset, opts DBContainerOptions) (container, func(), error) {
	// readyTimeout bounds how long we wait for pg_isready to succeed.
	const readyTimeout = 2 * time.Minute

	if opts.Name != "" {
		// We only want to start the container once per unique name,
		// so we take an inter-process lock to avoid concurrent test runs
		// racing with us.
		nameHash := sha256.Sum256([]byte(opts.Name))
		nameHashStr := hex.EncodeToString(nameHash[:])
		lock := flock.New(filepath.Join(os.TempDir(), "coder-postgres-container-"+nameHashStr[:8]))
		if err := lock.Lock(); err != nil {
			return container{}, nil, xerrors.Errorf("lock: %w", err)
		}
		defer func() {
			err := lock.Unlock()
			if err != nil {
				t.Logf("open container: failed to unlock: %s\n", err.Error())
			}
		}()
	}

	pool, err := dockertest.NewPool("")
	if err != nil {
		return container{}, nil, xerrors.Errorf("create pool: %w", err)
	}

	var resource *dockertest.Resource
	var tempDir string
	if opts.Name != "" {
		// If the container already exists, we'll use it.
		resource, _ = pool.ContainerByName(opts.Name)
	}
	if resource == nil {
		tempDir, err = os.MkdirTemp(os.TempDir(), "postgres")
		if err != nil {
			return container{}, nil, xerrors.Errorf("create tempdir: %w", err)
		}
		runOptions := dockertest.RunOptions{
			Repository: postgresImage,
			Tag:        strconv.Itoa(minimumPostgreSQLVersion),
			Env: []string{
				"POSTGRES_PASSWORD=postgres",
				"POSTGRES_USER=postgres",
				"POSTGRES_DB=postgres",
				// The location for temporary database files!
				"PGDATA=/tmp",
				"listen_addresses = '*'",
			},
			PortBindings: map[docker.Port][]docker.PortBinding{
				"5432/tcp": {{
					// Manually specifying a host IP tells Docker just to use an IPV4 address.
					// If we don't do this, we hit a fun bug:
					// https://github.com/moby/moby/issues/42442
					// where the ipv4 and ipv6 ports might be _different_ and collide with other running docker containers.
					HostIP:   "0.0.0.0",
					HostPort: strconv.FormatInt(int64(opts.Port), 10),
				}},
			},
			Mounts: []string{
				// The postgres image has a VOLUME parameter in its image.
				// If we don't mount at this point, Docker will allocate a
				// volume for this directory.
				//
				// This isn't used anyways, since we override PGDATA.
				fmt.Sprintf("%s:/var/lib/postgresql/data", tempDir),
			},
			Cmd: []string{"-c", "max_connections=1000"},
		}
		if opts.Name != "" {
			runOptions.Name = opts.Name
		}
		resource, err = pool.RunWithOptions(&runOptions, func(config *docker.HostConfig) {
			// set AutoRemove to true so that stopped container goes away by itself
			config.AutoRemove = true
			config.RestartPolicy = docker.RestartPolicy{Name: "no"}
			config.Tmpfs = map[string]string{
				"/tmp": "rw",
			}
		})
		if err != nil {
			return container{}, nil, xerrors.Errorf("could not start resource: %w", err)
		}
	}

	hostAndPort := resource.GetHostPort("5432/tcp")
	host, port, err := net.SplitHostPort(hostAndPort)
	if err != nil {
		return container{}, nil, xerrors.Errorf("split host and port: %w", err)
	}

	// Poll pg_isready inside the container until it exits with status 0.
	// A successful Exec only means the command ran; the exit code carries the
	// actual readiness result. The context bounds the total wait so that a
	// server that never becomes ready produces an error instead of a hang.
	readyCtx, cancelReady := context.WithTimeout(context.Background(), readyTimeout)
	defer cancelReady()
	ready := false
	for r := retry.New(50*time.Millisecond, time.Second); r.Wait(readyCtx); {
		stdout := &strings.Builder{}
		stderr := &strings.Builder{}
		var exitCode int
		exitCode, err = resource.Exec([]string{"pg_isready", "-h", "127.0.0.1"}, dockertest.ExecOptions{
			StdOut: stdout,
			StdErr: stderr,
		})
		if err != nil {
			continue
		}
		if exitCode == 0 {
			ready = true
			break
		}
		// Remember why the server was not ready for the final error message.
		err = xerrors.Errorf("exit code %d: stdout=%q stderr=%q", exitCode, stdout.String(), stderr.String())
	}
	if !ready {
		if err == nil {
			// The deadline expired before a single probe was made.
			err = readyCtx.Err()
		}
		return container{}, nil, xerrors.Errorf("pg_isready: %w", err)
	}

	return container{
			Host:     host,
			Port:     port,
			Resource: resource,
			Pool:     pool,
		}, func() {
			_ = pool.Purge(resource)
			if tempDir != "" {
				_ = os.RemoveAll(tempDir)
			}
		}, nil
}
// OpenContainerized starts a PostgreSQL server in a Docker container, waits
// until it accepts connections, runs the migrations against it and returns its
// connection URL. If port is nonzero, forward host traffic to that port to the
// database. If port is zero, allocate a free port from the OS.
// The user is responsible for calling the returned cleanup function. When an
// error is returned the container has already been cleaned up.
func OpenContainerized(t TBSubset, opts DBContainerOptions) (string, func(), error) {
	pg, cleanup, err := openContainer(t, opts)
	if err != nil {
		return "", nil, xerrors.Errorf("open container: %w", err)
	}
	// Every failure below leaves a non-nil err behind, which triggers cleanup
	// on the way out.
	defer func() {
		if err == nil {
			return
		}
		cleanup()
	}()

	connParams := ConnectionParams{
		Username: "postgres",
		Password: "postgres",
		Host:     pg.Host,
		Port:     pg.Port,
		DBName:   "postgres",
	}
	dbURL := connParams.DSN()

	// Docker should hard-kill the container after 120 seconds.
	if err = pg.Resource.Expire(120); err != nil {
		return "", nil, xerrors.Errorf("expire resource: %w", err)
	}

	pg.Pool.MaxWait = 120 * time.Second

	// Record the error that occurs during the retry.
	// The 'pool' pkg hardcodes a deadline error devoid
	// of any useful context.
	var retryErr error
	connectAndMigrate := func() error {
		db, openErr := sql.Open("postgres", dbURL)
		if openErr != nil {
			retryErr = xerrors.Errorf("open postgres: %w", openErr)
			return retryErr
		}
		defer db.Close()

		if pingErr := db.Ping(); pingErr != nil {
			retryErr = xerrors.Errorf("ping postgres: %w", pingErr)
			return retryErr
		}

		if migrateErr := migrations.Up(db); migrateErr != nil {
			retryErr = xerrors.Errorf("migrate db: %w", migrateErr)
			// Only try to migrate once.
			return backoff.Permanent(retryErr)
		}
		return nil
	}
	if err = pg.Pool.Retry(connectAndMigrate); err != nil {
		return "", nil, retryErr
	}

	return dbURL, cleanup, nil
}
// containsAnySubstring reports whether s contains at least one of the given
// substrings. It returns false when substrings is empty.
func containsAnySubstring(s string, substrings []string) bool {
	for i := range substrings {
		if !strings.Contains(s, substrings[i]) {
			continue
		}
		return true
	}
	return false
}