mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: Add reset-password command (#1380)
* allow non-destructively checking if database needs to be migrated * feat: Add reset-password command * fix linter errors * clean up reset-password usage prompt * Add confirmation to reset-password command * Ping database before checking migration, to improve error message
This commit is contained in:
@ -4,9 +4,11 @@ import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
"github.com/golang-migrate/migrate/v4/source"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
@ -14,28 +16,28 @@ import (
|
||||
//go:embed migrations/*.sql
|
||||
var migrations embed.FS
|
||||
|
||||
func migrateSetup(db *sql.DB) (*migrate.Migrate, error) {
|
||||
func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) {
|
||||
sourceDriver, err := iofs.New(migrations, "migrations")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create iofs: %w", err)
|
||||
return nil, nil, xerrors.Errorf("create iofs: %w", err)
|
||||
}
|
||||
|
||||
dbDriver, err := postgres.WithInstance(db, &postgres.Config{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("wrap postgres connection: %w", err)
|
||||
return nil, nil, xerrors.Errorf("wrap postgres connection: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("", sourceDriver, "", dbDriver)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new migrate instance: %w", err)
|
||||
return nil, nil, xerrors.Errorf("new migrate instance: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
return sourceDriver, m, nil
|
||||
}
|
||||
|
||||
// MigrateUp runs SQL migrations to ensure the database schema is up-to-date.
|
||||
func MigrateUp(db *sql.DB) error {
|
||||
m, err := migrateSetup(db)
|
||||
_, m, err := migrateSetup(db)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate setup: %w", err)
|
||||
}
|
||||
@ -55,7 +57,7 @@ func MigrateUp(db *sql.DB) error {
|
||||
|
||||
// MigrateDown runs all down SQL migrations.
|
||||
func MigrateDown(db *sql.DB) error {
|
||||
m, err := migrateSetup(db)
|
||||
_, m, err := migrateSetup(db)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate setup: %w", err)
|
||||
}
|
||||
@ -72,3 +74,68 @@ func MigrateDown(db *sql.DB) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureClean checks whether all migrations for the current version have been
|
||||
// applied, without making any changes to the database. If not, returns a
|
||||
// non-nil error.
|
||||
func EnsureClean(db *sql.DB) error {
|
||||
sourceDriver, m, err := migrateSetup(db)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate setup: %w", err)
|
||||
}
|
||||
|
||||
version, dirty, err := m.Version()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get migration version: %w", err)
|
||||
}
|
||||
|
||||
if dirty {
|
||||
return xerrors.Errorf("database has not been cleanly migrated")
|
||||
}
|
||||
|
||||
// Verify that the database's migration version is "current" by checking
|
||||
// that a migration with that version exists, but there is no next version.
|
||||
err = CheckLatestVersion(sourceDriver, version)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("database needs migration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns nil if currentVersion corresponds to the latest available migration,
|
||||
// otherwise an error explaining why not.
|
||||
func CheckLatestVersion(sourceDriver source.Driver, currentVersion uint) error {
|
||||
// This is ugly, but seems like the only way to do it with the public
|
||||
// interfaces provided by golang-migrate.
|
||||
|
||||
// Check that there is no later version
|
||||
nextVersion, err := sourceDriver.Next(currentVersion)
|
||||
if err == nil {
|
||||
return xerrors.Errorf("current version is %d, but later version %d exists", currentVersion, nextVersion)
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return xerrors.Errorf("get next migration after %d: %w", currentVersion, err)
|
||||
}
|
||||
|
||||
// Once we reach this point, we know that either currentVersion doesn't
|
||||
// exist, or it has no successor (the return value from
|
||||
// sourceDriver.Next() is the same in either case). So we need to check
|
||||
// that either it's the first version, or it has a predecessor.
|
||||
|
||||
firstVersion, err := sourceDriver.First()
|
||||
if err != nil {
|
||||
// the total number of migrations should be non-zero, so this must be
|
||||
// an actual error, not just a missing file
|
||||
return xerrors.Errorf("get first migration: %w", err)
|
||||
}
|
||||
if firstVersion == currentVersion {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = sourceDriver.Prev(currentVersion)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get previous migration: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -4,8 +4,11 @@ package database_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/source"
|
||||
"github.com/golang-migrate/migrate/v4/source/stub"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
@ -75,3 +78,54 @@ func testSQLDB(t testing.TB) *sql.DB {
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// paralleltest linter doesn't correctly handle table-driven tests (https://github.com/kunwardeep/paralleltest/issues/8)
|
||||
// nolint:paralleltest
|
||||
func TestCheckLatestVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type test struct {
|
||||
currentVersion uint
|
||||
existingVersions []uint
|
||||
expectedResult string
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
// successful cases
|
||||
{1, []uint{1}, ""},
|
||||
{3, []uint{1, 2, 3}, ""},
|
||||
{3, []uint{1, 3}, ""},
|
||||
|
||||
// failure cases
|
||||
{1, []uint{1, 2}, "current version is 1, but later version 2 exists"},
|
||||
{2, []uint{1, 2, 3}, "current version is 2, but later version 3 exists"},
|
||||
{4, []uint{1, 2, 3}, "get previous migration: prev for version 4 : file does not exist"},
|
||||
{4, []uint{1, 2, 3, 5}, "get previous migration: prev for version 4 : file does not exist"},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
i, tc := i, tc
|
||||
t.Run(fmt.Sprintf("entry %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
driver, _ := stub.WithInstance(nil, &stub.Config{})
|
||||
stub, ok := driver.(*stub.Stub)
|
||||
require.True(t, ok)
|
||||
for _, version := range tc.existingVersions {
|
||||
stub.Migrations.Append(&source.Migration{
|
||||
Version: version,
|
||||
Identifier: "",
|
||||
Direction: source.Up,
|
||||
Raw: "",
|
||||
})
|
||||
}
|
||||
|
||||
err := database.CheckLatestVersion(driver, tc.currentVersion)
|
||||
var errMessage string
|
||||
if err != nil {
|
||||
errMessage = err.Error()
|
||||
}
|
||||
require.Equal(t, tc.expectedResult, errMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user