mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
Updates coder/customers#365. This PR updates our migration framework to run all migrations in a single transaction. This is the same behavior we had in v1 and ensures that failed migrations don't bring the whole deployment down. If a migration fails now, it will automatically be rolled back to the previous version, allowing the deployment to continue functioning.
213 lines
5.4 KiB
Go
213 lines
5.4 KiB
Go
package migrations
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"embed"
|
|
"errors"
|
|
"io/fs"
|
|
"os"
|
|
|
|
"github.com/golang-migrate/migrate/v4"
|
|
"github.com/golang-migrate/migrate/v4/source"
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
"golang.org/x/xerrors"
|
|
)
|
|
|
|
// migrations holds every *.sql file in this package's directory, embedded
// into the binary at build time by the go:embed directive below. It is the
// filesystem that the iofs migration source driver reads from.
//
//go:embed *.sql
var migrations embed.FS
|
|
|
|
func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
|
|
ctx := context.Background()
|
|
sourceDriver, err := iofs.New(migrations, ".")
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("create iofs: %w", err)
|
|
}
|
|
|
|
// migration_cursor is a v1 migration table. If this exists, we're on v1.
|
|
// Do no run v2 migrations on a v1 database!
|
|
row := db.QueryRowContext(ctx, "SELECT 1 FROM information_schema.tables WHERE table_schema = current_schema() AND table_name = 'migration_cursor';")
|
|
var v1Exists int
|
|
if row.Scan(&v1Exists) == nil {
|
|
return nil, nil, xerrors.New("currently connected to a Coder v1 database, aborting database setup")
|
|
}
|
|
|
|
dbDriver := &pgTxnDriver{ctx: context.Background(), db: db}
|
|
err = dbDriver.ensureVersionTable()
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("ensure version table: %w", err)
|
|
}
|
|
|
|
m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("new migrate instance: %w", err)
|
|
}
|
|
|
|
return sourceDriver, m, nil
|
|
}
|
|
|
|
// Up runs SQL migrations to ensure the database schema is up-to-date.
|
|
func Up(db *sql.DB) (retErr error) {
|
|
_, m, err := setup(db)
|
|
if err != nil {
|
|
return xerrors.Errorf("migrate setup: %w", err)
|
|
}
|
|
defer func() {
|
|
srcErr, dbErr := m.Close()
|
|
if retErr != nil {
|
|
return
|
|
}
|
|
if dbErr != nil {
|
|
retErr = dbErr
|
|
return
|
|
}
|
|
retErr = srcErr
|
|
}()
|
|
|
|
err = m.Up()
|
|
if err != nil {
|
|
if errors.Is(err, migrate.ErrNoChange) {
|
|
// It's OK if no changes happened!
|
|
return nil
|
|
}
|
|
|
|
return xerrors.Errorf("up: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Down runs all down SQL migrations.
|
|
func Down(db *sql.DB) error {
|
|
_, m, err := setup(db)
|
|
if err != nil {
|
|
return xerrors.Errorf("migrate setup: %w", err)
|
|
}
|
|
|
|
err = m.Down()
|
|
if err != nil {
|
|
if errors.Is(err, migrate.ErrNoChange) {
|
|
// It's OK if no changes happened!
|
|
return nil
|
|
}
|
|
|
|
return xerrors.Errorf("down: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnsureClean checks whether all migrations for the current version have been
|
|
// applied, without making any changes to the database. If not, returns a
|
|
// non-nil error.
|
|
func EnsureClean(db *sql.DB) error {
|
|
sourceDriver, m, err := setup(db)
|
|
if err != nil {
|
|
return xerrors.Errorf("migrate setup: %w", err)
|
|
}
|
|
|
|
version, dirty, err := m.Version()
|
|
if err != nil {
|
|
return xerrors.Errorf("get migration version: %w", err)
|
|
}
|
|
|
|
if dirty {
|
|
return xerrors.Errorf("database has not been cleanly migrated")
|
|
}
|
|
|
|
// Verify that the database's migration version is "current" by checking
|
|
// that a migration with that version exists, but there is no next version.
|
|
err = CheckLatestVersion(sourceDriver, version)
|
|
if err != nil {
|
|
return xerrors.Errorf("database needs migration: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Returns nil if currentVersion corresponds to the latest available migration,
|
|
// otherwise an error explaining why not.
|
|
func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
|
|
// This is ugly, but seems like the only way to do it with the public
|
|
// interfaces provided by golang-migrate.
|
|
|
|
// Check that there is no later version
|
|
nextVersion, err := sourceDriver.Next(currentVersion)
|
|
if err == nil {
|
|
return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, nextVersion)
|
|
}
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return xerrors.Errorf("get next migration after %d: %w", currentVersion, err)
|
|
}
|
|
|
|
// Once we reach this point, we know that either currentVersion doesn't
|
|
// exist, or it has no successor (the return value from
|
|
// sourceDriver.Next() is the same in either case). So we need to check
|
|
// that either it's the first version, or it has a predecessor.
|
|
|
|
firstVersion, err := sourceDriver.First()
|
|
if err != nil {
|
|
// the total number of migrations should be non-zero, so this must be
|
|
// an actual error, not just a missing file
|
|
return xerrors.Errorf("get first migration: %w", err)
|
|
}
|
|
if firstVersion == currentVersion {
|
|
return nil
|
|
}
|
|
|
|
_, err = sourceDriver.Prev(currentVersion)
|
|
if err != nil {
|
|
return xerrors.Errorf("get previous migration: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Stepper returns a function that runs SQL migrations one step at a time.
|
|
//
|
|
// Stepper cannot be closed pre-emptively, it must be run to completion
|
|
// (or until an error is encountered).
|
|
func Stepper(db *sql.DB) (next func() (version uint, more bool, err error), err error) {
|
|
_, m, err := setup(db)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("migrate setup: %w", err)
|
|
}
|
|
|
|
return func() (version uint, more bool, err error) {
|
|
defer func() {
|
|
if !more {
|
|
srcErr, dbErr := m.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if dbErr != nil {
|
|
err = dbErr
|
|
return
|
|
}
|
|
err = srcErr
|
|
}
|
|
}()
|
|
|
|
err = m.Steps(1)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, migrate.ErrNoChange):
|
|
// It's OK if no changes happened!
|
|
return 0, false, nil
|
|
case errors.Is(err, fs.ErrNotExist):
|
|
// This error is encountered at the of Steps when
|
|
// reading from embed.FS.
|
|
return 0, false, nil
|
|
}
|
|
|
|
return 0, false, xerrors.Errorf("Step: %w", err)
|
|
}
|
|
|
|
v, _, err := m.Version()
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
|
|
return v, true, nil
|
|
}, nil
|
|
}
|