Files
coder/coderd/database/migrations/migrate_test.go

132 lines
2.7 KiB
Go

//go:build linux
package migrations_test
import (
"database/sql"
"fmt"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/stub"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/coderd/database/migrations"
"github.com/coder/coder/coderd/database/postgres"
)
// TestMain is the entry point for this package's tests. It delegates the whole
// run to goleak.VerifyTestMain, which (per goleak's documentation) executes the
// tests and then fails the run if unexpected goroutines are still alive.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
// TestMigrate runs migrations.Up and migrations.Down in several sequences,
// each against its own database handle obtained from testSQLDB. The whole
// test is skipped when the test binary runs with -short.
func TestMigrate(t *testing.T) {
	t.Parallel()

	if testing.Short() {
		// t.Skip stops this test's goroutine by itself, so no explicit
		// return is needed after it.
		t.Skip("skipping migration tests in short mode")
	}

	// A single Up must succeed.
	t.Run("Once", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		err := migrations.Up(db)
		require.NoError(t, err)
	})

	// A second Up on the same database must not return an error.
	t.Run("Twice", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		err := migrations.Up(db)
		require.NoError(t, err)

		err = migrations.Up(db)
		require.NoError(t, err)
	})

	// Up, then Down, then Up again must all succeed on the same database.
	t.Run("UpDownUp", func(t *testing.T) {
		t.Parallel()

		db := testSQLDB(t)

		err := migrations.Up(db)
		require.NoError(t, err)

		err = migrations.Down(db)
		require.NoError(t, err)

		err = migrations.Up(db)
		require.NoError(t, err)
	})
}
// testSQLDB returns a *sql.DB for use in a test. It asks postgres.Open for a
// connection string, opens it with the "postgres" driver, and fails the test
// immediately if either step returns an error.
//
// Two cleanups are registered on t: the close function returned by
// postgres.Open, and a close of the *sql.DB itself. Cleanups run in
// last-registered-first order, so the handle is closed before the
// postgres.Open close function is invoked.
func testSQLDB(t testing.TB) *sql.DB {
	t.Helper()

	dsn, stopPostgres, err := postgres.Open()
	require.NoError(t, err)
	t.Cleanup(stopPostgres)

	sqlDB, err := sql.Open("postgres", dsn)
	require.NoError(t, err)
	t.Cleanup(func() {
		// The close error is deliberately discarded during teardown.
		_ = sqlDB.Close()
	})

	return sqlDB
}
// TestCheckLatestVersion is a table-driven test for
// migrations.CheckLatestVersion. Each entry fills a stub source driver with
// the listed "up" migration versions, calls CheckLatestVersion with
// currentVersion, and compares the returned error text with expectedResult
// (the empty string means no error is expected).
//
// paralleltest linter doesn't correctly handle table-driven tests (https://github.com/kunwardeep/paralleltest/issues/8)
// nolint:paralleltest
func TestCheckLatestVersion(t *testing.T) {
	t.Parallel()

	type test struct {
		currentVersion   uint
		existingVersions []uint
		expectedResult   string
	}

	tests := []test{
		// successful cases
		{1, []uint{1}, ""},
		{3, []uint{1, 2, 3}, ""},
		{3, []uint{1, 3}, ""},

		// failure cases
		{1, []uint{1, 2}, "current version is 1, but later version 2 exists"},
		{2, []uint{1, 2, 3}, "current version is 2, but later version 3 exists"},
		{4, []uint{1, 2, 3}, "get previous migration: prev for version 4 : file does not exist"},
		{4, []uint{1, 2, 3, 5}, "get previous migration: prev for version 4 : file does not exist"},
	}

	for i, tc := range tests {
		// Copy the loop variables so each parallel subtest sees its own values.
		i, tc := i, tc
		t.Run(fmt.Sprintf("entry %d", i), func(t *testing.T) {
			t.Parallel()

			// Fail loudly if the stub driver cannot be built instead of
			// discarding the error.
			driver, err := stub.WithInstance(nil, &stub.Config{})
			require.NoError(t, err)

			// Named stubDriver so the imported stub package is not shadowed.
			stubDriver, ok := driver.(*stub.Stub)
			require.True(t, ok)

			for _, version := range tc.existingVersions {
				stubDriver.Migrations.Append(&source.Migration{
					Version:    version,
					Identifier: "",
					Direction:  source.Up,
					Raw:        "",
				})
			}

			err = migrations.CheckLatestVersion(driver, tc.currentVersion)

			// Reduce the error to its message so success ("") and failure
			// cases can share one comparison.
			var errMessage string
			if err != nil {
				errMessage = err.Error()
			}
			require.Equal(t, tc.expectedResult, errMessage)
		})
	}
}