Files
coder/coderd/database/migrations/migrate.go
Hugo Dutka 1bfa7d42e8 chore: add postgres template caching for tests (#15336)
This PR is the first in a series aimed at closing
[#15109](https://github.com/coder/coder/issues/15109).

### Changes

- **Template Database Creation:**  
`dbtestutil.Open` can now create a template database if none is provided
via `DB_FROM`. The template database’s name is derived from a hash of the
migration files, so it can be reused across tests, and a new template is
created automatically whenever the migrations change.

- **Optimized Database Handling:**  
Previously, `dbtestutil.Open` would spin up a new container for each
test when `DB_FROM` was unset. Now, it first checks for an active
PostgreSQL instance on `localhost:5432`. If none is found, it creates a
single container that remains available for subsequent tests,
eliminating repeated container startups.

These changes address the long per-test run times (10+ seconds) reported
by some users, which were most likely caused by the time Docker took to
start a container and apply all migrations.
2024-11-04 17:23:31 +01:00

276 lines
7.1 KiB
Go

package migrations
import (
"context"
"crypto/sha256"
"database/sql"
"embed"
"errors"
"fmt"
"io/fs"
"os"
"sort"
"strings"
"sync"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/iofs"
"golang.org/x/xerrors"
)
// migrations holds the SQL migration files (every *.sql file in this
// package's directory), embedded into the binary at build time.
//
//go:embed *.sql
var migrations embed.FS

// migrationsHash caches the value returned by GetMigrationsHash;
// migrationsHashOnce ensures it is computed at most once per process.
var (
	migrationsHash     string
	migrationsHashOnce sync.Once
)
// A migrations hash is a sha256 hash of the contents and names
// of the migrations sorted by filename.
func calculateMigrationsHash(migrationsFs embed.FS) (string, error) {
files, err := migrationsFs.ReadDir(".")
if err != nil {
return "", xerrors.Errorf("read migrations directory: %w", err)
}
sortedFiles := make([]fs.DirEntry, len(files))
copy(sortedFiles, files)
sort.Slice(sortedFiles, func(i, j int) bool {
return sortedFiles[i].Name() < sortedFiles[j].Name()
})
var builder strings.Builder
for _, file := range sortedFiles {
if _, err := builder.WriteString(file.Name()); err != nil {
return "", xerrors.Errorf("write migration file name %q: %w", file.Name(), err)
}
content, err := migrationsFs.ReadFile(file.Name())
if err != nil {
return "", xerrors.Errorf("read migration file %q: %w", file.Name(), err)
}
if _, err := builder.Write(content); err != nil {
return "", xerrors.Errorf("write migration file content %q: %w", file.Name(), err)
}
}
hash := sha256.New()
if _, err := hash.Write([]byte(builder.String())); err != nil {
return "", xerrors.Errorf("write to hash: %w", err)
}
return fmt.Sprintf("%x", hash.Sum(nil)), nil
}
// GetMigrationsHash returns the hex-encoded SHA-256 hash of the embedded
// migration files, as computed by calculateMigrationsHash. The value is
// computed on the first call and cached for the rest of the process.
//
// It panics if the hash cannot be computed; the embedded migrations are part
// of the binary, so a failure here presumably indicates a broken build.
func GetMigrationsHash() string {
	migrationsHashOnce.Do(func() {
		sum, err := calculateMigrationsHash(migrations)
		if err == nil {
			migrationsHash = sum
			return
		}
		panic(err)
	})
	return migrationsHash
}
// setup builds a golang-migrate instance that reads migrations from the root
// of migs (or from the embedded migrations when migs is nil) and applies them
// to db through a pgTxnDriver.
//
// It refuses to continue if db contains the Coder v1 migration table, and
// calls ensureVersionTable on the driver before creating the instance. The
// returned source driver is the same one the migrate instance reads from.
func setup(db *sql.DB, migs fs.FS) (source.Driver, *migrate.Migrate, error) {
	if migs == nil {
		migs = migrations
	}
	ctx := context.Background()

	sourceDriver, err := iofs.New(migs, ".")
	if err != nil {
		return nil, nil, xerrors.Errorf("create iofs: %w", err)
	}

	// migration_cursor is a v1 migration table. If it exists, we're on v1.
	// Do not run v2 migrations on a v1 database!
	var v1Exists int
	err = db.QueryRowContext(ctx, "SELECT 1 FROM information_schema.tables WHERE table_schema = current_schema() AND table_name = 'migration_cursor';").Scan(&v1Exists)
	switch {
	case err == nil:
		return nil, nil, xerrors.New("currently connected to a Coder v1 database, aborting database setup")
	case !errors.Is(err, sql.ErrNoRows):
		// Any other failure means we could not determine whether this is a
		// v1 database; report it rather than assuming it is not.
		return nil, nil, xerrors.Errorf("check for v1 migration table: %w", err)
	}

	dbDriver := &pgTxnDriver{ctx: ctx, db: db}
	if err := dbDriver.ensureVersionTable(); err != nil {
		return nil, nil, xerrors.Errorf("ensure version table: %w", err)
	}

	m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
	if err != nil {
		return nil, nil, xerrors.Errorf("new migrate instance: %w", err)
	}
	return sourceDriver, m, nil
}
// Up runs SQL migrations to ensure the database schema is up-to-date.
func Up(db *sql.DB) error {
return UpWithFS(db, migrations)
}
// UpWithFS applies all pending SQL migrations found in the root of migs. When
// migs is nil, the embedded migrations are used instead. A database that is
// already up to date is not treated as an error.
//
// The migrate instance is always closed before returning. A close error is
// reported only if migrating succeeded, and the database driver's close error
// takes precedence over the source driver's.
func UpWithFS(db *sql.DB, migs fs.FS) (retErr error) {
	_, m, err := setup(db, migs)
	if err != nil {
		return xerrors.Errorf("migrate setup: %w", err)
	}
	defer func() {
		srcErr, dbErr := m.Close()
		switch {
		case retErr != nil:
			// Keep the migration error; close errors are secondary.
		case dbErr != nil:
			retErr = dbErr
		default:
			retErr = srcErr
		}
	}()

	upErr := m.Up()
	if upErr == nil || errors.Is(upErr, migrate.ErrNoChange) {
		// Having nothing to apply is as good as a successful migration.
		return nil
	}
	return xerrors.Errorf("up: %w", upErr)
}
// Down runs all down SQL migrations from the embedded migration files. A
// database with nothing to roll back is not treated as an error.
//
// The migrate instance is closed before returning, the same way UpWithFS
// does it. A close error is reported only if the rollback itself succeeded,
// and the database driver's close error takes precedence over the source
// driver's.
func Down(db *sql.DB) (retErr error) {
	_, m, err := setup(db, migrations)
	if err != nil {
		return xerrors.Errorf("migrate setup: %w", err)
	}
	defer func() {
		srcErr, dbErr := m.Close()
		if retErr != nil {
			return
		}
		if dbErr != nil {
			retErr = dbErr
			return
		}
		retErr = srcErr
	}()

	err = m.Down()
	if err != nil {
		if errors.Is(err, migrate.ErrNoChange) {
			// It's OK if no changes happened!
			return nil
		}
		return xerrors.Errorf("down: %w", err)
	}
	return nil
}
// EnsureClean checks whether all migrations for the current version have been
// applied, without running any migrations. It returns a non-nil error if the
// recorded migration version is dirty, cannot be read, or is not the latest
// available migration.
//
// NOTE(review): setup calls ensureVersionTable, which presumably creates the
// version table if it is missing, so this check may not be strictly
// read-only. Confirm against pgTxnDriver.
//
// The migrate instance is closed before returning. A close error is reported
// only if the check itself passed.
func EnsureClean(db *sql.DB) (retErr error) {
	sourceDriver, m, err := setup(db, migrations)
	if err != nil {
		return xerrors.Errorf("migrate setup: %w", err)
	}
	defer func() {
		srcErr, dbErr := m.Close()
		if retErr != nil {
			return
		}
		if dbErr != nil {
			retErr = dbErr
			return
		}
		retErr = srcErr
	}()

	version, dirty, err := m.Version()
	if err != nil {
		return xerrors.Errorf("get migration version: %w", err)
	}
	if dirty {
		return xerrors.New("database has not been cleanly migrated")
	}

	// Verify that the database's migration version is "current" by checking
	// that a migration with that version exists, but there is no next version.
	if err := CheckLatestVersion(sourceDriver, version); err != nil {
		return xerrors.Errorf("database needs migration: %w", err)
	}
	return nil
}
// CheckLatestVersion returns nil if currentVersion corresponds to the latest
// migration available from sourceDriver; otherwise it returns an error
// explaining why not.
//
// The public golang-migrate source.Driver interface has no direct "is this
// the last version?" query, so the answer is assembled from Next, First and
// Prev.
func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
	// There must be no later version.
	later, nextErr := sourceDriver.Next(currentVersion)
	switch {
	case nextErr == nil:
		return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, later)
	case !errors.Is(nextErr, os.ErrNotExist):
		return xerrors.Errorf("get next migration after %d: %w", currentVersion, nextErr)
	}

	// Next reports "not exist" both when currentVersion has no successor and
	// when currentVersion is not a known migration at all. To tell these
	// apart, require that currentVersion is either the first migration or has
	// a predecessor.
	first, err := sourceDriver.First()
	if err != nil {
		// The source is expected to contain at least one migration, so this
		// is treated as a real error rather than an empty source.
		return xerrors.Errorf("get first migration: %w", err)
	}
	if currentVersion == first {
		return nil
	}
	if _, err := sourceDriver.Prev(currentVersion); err != nil {
		return xerrors.Errorf("get previous migration: %w", err)
	}
	return nil
}
// Stepper returns a function that runs SQL migrations one step at a time.
//
// Each call to next applies a single up migration and returns the resulting
// schema version with more == true. When there is nothing left to apply, or
// an error occurs, next returns more == false and closes the underlying
// migrate instance. Calling next again after that is not supported.
//
// Stepper cannot be closed pre-emptively, it must be run to completion
// (or until an error is encountered).
func Stepper(db *sql.DB) (next func() (version uint, more bool, err error), err error) {
	_, m, err := setup(db, migrations)
	if err != nil {
		return nil, xerrors.Errorf("migrate setup: %w", err)
	}

	return func() (version uint, more bool, err error) {
		defer func() {
			if more {
				return
			}
			// Finished (or failed): release the migrate instance. A close
			// error is reported only if the step itself succeeded.
			srcErr, dbErr := m.Close()
			if err != nil {
				return
			}
			if dbErr != nil {
				err = dbErr
				return
			}
			err = srcErr
		}()

		err = m.Steps(1)
		if err != nil {
			switch {
			case errors.Is(err, migrate.ErrNoChange):
				// It's OK if no changes happened!
				return 0, false, nil
			case errors.Is(err, fs.ErrNotExist):
				// This error is encountered at the end of Steps when
				// reading from embed.FS.
				return 0, false, nil
			}
			return 0, false, xerrors.Errorf("step: %w", err)
		}

		v, _, err := m.Version()
		if err != nil {
			return 0, false, xerrors.Errorf("get migration version: %w", err)
		}
		return v, true, nil
	}, nil
}