feat(scripts): add script to check schema between migrations (#13037)

- migrations: allow passing in a custom migrate.FS
- gen/dump: extract some functions to dbtestutil
- scripts: write script to test migrations
This commit is contained in:
Cian Johnston
2024-04-23 12:43:14 +01:00
committed by GitHub
parent 81fcdf717b
commit e57ca3cdaa
5 changed files with 252 additions and 102 deletions

View File

@ -10,6 +10,7 @@ import (
"os/exec"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"
@ -184,20 +185,21 @@ func DumpOnFailure(t testing.TB, connectionURL string) {
now := time.Now()
timeSuffix := fmt.Sprintf("%d%d%d%d%d%d", now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second())
outPath := filepath.Join(cwd, snakeCaseName+"."+timeSuffix+".test.sql")
dump, err := pgDump(connectionURL)
dump, err := PGDump(connectionURL)
if err != nil {
t.Errorf("dump on failure: failed to run pg_dump")
return
}
if err := os.WriteFile(outPath, filterDump(dump), 0o600); err != nil {
if err := os.WriteFile(outPath, normalizeDump(dump), 0o600); err != nil {
t.Errorf("dump on failure: failed to write: %s", err.Error())
return
}
t.Logf("Dumped database to %q due to failed test. I hope you find what you're looking for!", outPath)
}
// pgDump runs pg_dump against dbURL and returns the output.
func pgDump(dbURL string) ([]byte, error) {
// PGDump runs pg_dump against dbURL and returns the output.
// It is used by DumpOnFailure().
func PGDump(dbURL string) ([]byte, error) {
if _, err := exec.LookPath("pg_dump"); err != nil {
return nil, xerrors.Errorf("could not find pg_dump in path: %w", err)
}
@ -230,16 +232,79 @@ func pgDump(dbURL string) ([]byte, error) {
return stdout.Bytes(), nil
}
// Unfortunately, some insert expressions span multiple lines.
// The below may be over-permissive but better that than truncating data.
var insertExpr = regexp.MustCompile(`(?s)\bINSERT[^;]+;`)
const minimumPostgreSQLVersion = 13
func filterDump(dump []byte) []byte {
var buf bytes.Buffer
matches := insertExpr.FindAll(dump, -1)
for _, m := range matches {
_, _ = buf.Write(m)
_, _ = buf.WriteRune('\n')
// PGDumpSchemaOnly is for use by gen/dump only.
// It runs pg_dump against dbURL and sets a consistent timezone and encoding.
func PGDumpSchemaOnly(dbURL string) ([]byte, error) {
hasPGDump := false
if _, err := exec.LookPath("pg_dump"); err == nil {
out, err := exec.Command("pg_dump", "--version").Output()
if err == nil {
// Parse output:
// pg_dump (PostgreSQL) 14.5 (Ubuntu 14.5-0ubuntu0.22.04.1)
parts := strings.Split(string(out), " ")
if len(parts) > 2 {
version, err := strconv.Atoi(strings.Split(parts[2], ".")[0])
if err == nil && version >= minimumPostgreSQLVersion {
hasPGDump = true
}
}
}
}
return buf.Bytes()
cmdArgs := []string{
"pg_dump",
"--schema-only",
dbURL,
"--no-privileges",
"--no-owner",
"--no-privileges",
"--no-publication",
"--no-security-labels",
"--no-subscriptions",
"--no-tablespaces",
// We never want to manually generate
// queries executing against this table.
"--exclude-table=schema_migrations",
}
if !hasPGDump {
cmdArgs = append([]string{
"docker",
"run",
"--rm",
"--network=host",
fmt.Sprintf("gcr.io/coder-dev-1/postgres:%d", minimumPostgreSQLVersion),
}, cmdArgs...)
}
cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //#nosec
cmd.Env = append(os.Environ(), []string{
"PGTZ=UTC",
"PGCLIENTENCODING=UTF8",
}...)
var output bytes.Buffer
cmd.Stdout = &output
cmd.Stderr = os.Stderr
err := cmd.Run()
if err != nil {
return nil, err
}
return normalizeDump(output.Bytes()), nil
}
func normalizeDump(schema []byte) []byte {
// Remove all comments.
schema = regexp.MustCompile(`(?im)^(--.*)$`).ReplaceAll(schema, []byte{})
// Public is implicit in the schema.
schema = regexp.MustCompile(`(?im)( |::|'|\()public\.`).ReplaceAll(schema, []byte(`$1`))
// Remove database settings.
schema = regexp.MustCompile(`(?im)^(SET.*;)`).ReplaceAll(schema, []byte(``))
// Remove select statements
schema = regexp.MustCompile(`(?im)^(SELECT.*;)`).ReplaceAll(schema, []byte(``))
// Removes multiple newlines.
schema = regexp.MustCompile(`(?im)\n{3,}`).ReplaceAll(schema, []byte("\n\n"))
return schema
}

View File

@ -1,21 +1,16 @@
package main
import (
"bytes"
"database/sql"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/migrations"
)
const minimumPostgreSQLVersion = 13
var preamble = []byte("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.")
func main() {
connection, closeFn, err := dbtestutil.Open()
@ -28,95 +23,23 @@ func main() {
if err != nil {
panic(err)
}
defer db.Close()
err = migrations.Up(db)
if err != nil {
panic(err)
}
hasPGDump := false
if _, err = exec.LookPath("pg_dump"); err == nil {
out, err := exec.Command("pg_dump", "--version").Output()
if err == nil {
// Parse output:
// pg_dump (PostgreSQL) 14.5 (Ubuntu 14.5-0ubuntu0.22.04.1)
parts := strings.Split(string(out), " ")
if len(parts) > 2 {
version, err := strconv.Atoi(strings.Split(parts[2], ".")[0])
if err == nil && version >= minimumPostgreSQLVersion {
hasPGDump = true
}
}
}
}
cmdArgs := []string{
"pg_dump",
"--schema-only",
connection,
"--no-privileges",
"--no-owner",
// We never want to manually generate
// queries executing against this table.
"--exclude-table=schema_migrations",
}
if !hasPGDump {
cmdArgs = append([]string{
"docker",
"run",
"--rm",
"--network=host",
fmt.Sprintf("gcr.io/coder-dev-1/postgres:%d", minimumPostgreSQLVersion),
}, cmdArgs...)
}
cmd := exec.Command(cmdArgs[0], cmdArgs[1:]...) //#nosec
cmd.Env = append(os.Environ(), []string{
"PGTZ=UTC",
"PGCLIENTENCODING=UTF8",
}...)
var output bytes.Buffer
cmd.Stdout = &output
cmd.Stderr = os.Stderr
err = cmd.Run()
dumpBytes, err := dbtestutil.PGDumpSchemaOnly(connection)
if err != nil {
panic(err)
}
for _, sed := range []string{
// Remove all comments.
"/^--/d",
// Public is implicit in the schema.
"s/ public\\./ /g",
"s/::public\\./::/g",
"s/'public\\./'/g",
"s/(public\\./(/g",
// Remove database settings.
"s/SET .* = .*;//g",
// Remove select statements. These aren't useful
// to a reader of the dump.
"s/SELECT.*;//g",
// Removes multiple newlines.
"/^$/N;/^\\n$/D",
} {
cmd := exec.Command("sed", "-e", sed)
cmd.Stdin = bytes.NewReader(output.Bytes())
output = bytes.Buffer{}
cmd.Stdout = &output
cmd.Stderr = os.Stderr
err = cmd.Run()
if err != nil {
panic(err)
}
}
dump := fmt.Sprintf("-- Code generated by 'make coderd/database/generate'. DO NOT EDIT.\n%s", output.Bytes())
_, mainPath, _, ok := runtime.Caller(0)
if !ok {
panic("couldn't get caller path")
}
err = os.WriteFile(filepath.Join(mainPath, "..", "..", "..", "dump.sql"), []byte(dump), 0o600)
err = os.WriteFile(filepath.Join(mainPath, "..", "..", "..", "dump.sql"), append(preamble, dumpBytes...), 0o600)
if err != nil {
panic(err)
}

View File

@ -17,9 +17,12 @@ import (
//go:embed *.sql
var migrations embed.FS
func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
func setup(db *sql.DB, migs fs.FS) (source.Driver, *migrate.Migrate, error) {
if migs == nil {
migs = migrations
}
ctx := context.Background()
sourceDriver, err := iofs.New(migrations, ".")
sourceDriver, err := iofs.New(migs, ".")
if err != nil {
return nil, nil, xerrors.Errorf("create iofs: %w", err)
}
@ -47,8 +50,13 @@ func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
}
// Up runs SQL migrations to ensure the database schema is up-to-date.
func Up(db *sql.DB) (retErr error) {
_, m, err := setup(db)
func Up(db *sql.DB) error {
return UpWithFS(db, migrations)
}
// UpWithFS runs SQL migrations in the given fs.
func UpWithFS(db *sql.DB, migs fs.FS) (retErr error) {
_, m, err := setup(db, migs)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -79,7 +87,7 @@ func Up(db *sql.DB) (retErr error) {
// Down runs all down SQL migrations.
func Down(db *sql.DB) error {
_, m, err := setup(db)
_, m, err := setup(db, migrations)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -101,7 +109,7 @@ func Down(db *sql.DB) error {
// applied, without making any changes to the database. If not, returns a
// non-nil error.
func EnsureClean(db *sql.DB) error {
sourceDriver, m, err := setup(db)
sourceDriver, m, err := setup(db, migrations)
if err != nil {
return xerrors.Errorf("migrate setup: %w", err)
}
@ -167,7 +175,7 @@ func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
// Stepper cannot be closed pre-emptively, it must be run to completion
// (or until an error is encountered).
func Stepper(db *sql.DB) (next func() (version uint, more bool, err error), err error) {
_, m, err := setup(db)
_, m, err := setup(db, migrations)
if err != nil {
return nil, xerrors.Errorf("migrate setup: %w", err)
}