feat(scripts): add script to check schema between migrations (#13037)

- migrations: allow passing in a custom migrate.FS
- gen/dump: extract some functions to dbtestutil
- scripts: write script to test migrations
This commit is contained in:
Cian Johnston
2024-04-23 12:43:14 +01:00
committed by GitHub
parent 81fcdf717b
commit e57ca3cdaa
5 changed files with 252 additions and 102 deletions

View File

@ -17,9 +17,12 @@ import (
//go:embed *.sql
var migrations embed.FS
func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
func setup(db *sql.DB, migs fs.FS) (source.Driver, *migrate.Migrate, error) {
if migs == nil {
migs = migrations
}
ctx := context.Background()
sourceDriver, err := iofs.New(migrations, ".")
sourceDriver, err := iofs.New(migs, ".")
if err != nil {
return nil, nil, xerrors.Errorf("create iofs: %w", err)
}
@ -47,8 +50,13 @@ func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
}
// Up runs SQL migrations to ensure the database schema is up-to-date.
func Up(db *sql.DB) (retErr error) {
_, m, err := setup(db)
func Up(db *sql.DB) error {
return UpWithFS(db, migrations)
}
// UpWithFS runs SQL migrations in the given fs.
func UpWithFS(db *sql.DB, migs fs.FS) (retErr error) {
_, m, err := setup(db, migs)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -79,7 +87,7 @@ func Up(db *sql.DB) (retErr error) {
// Down runs all down SQL migrations.
func Down(db *sql.DB) error {
_, m, err := setup(db)
_, m, err := setup(db, migrations)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -101,7 +109,7 @@ func Down(db *sql.DB) error {
// applied, without making any changes to the database. If not, returns a
// non-nil error.
func EnsureClean(db *sql.DB) error {
sourceDriver, m, err := setup(db)
sourceDriver, m, err := setup(db, migrations)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -167,7 +175,7 @@ func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
// Stepper cannot be closed pre-emptively, it must be run to completion
// (or until an error is encountered).
func Stepper(db *sql.DB) (next func() (version uint, more bool, err error), err error) {
_, m, err := setup(db)
_, m, err := setup(db, migrations)
if err != nil {
return nil, xerrors.Errorf("migrate setup: %w", err)
}