chore: prevent db migrations from running on all cli commands (#15980)

This commit is contained in:
Steven Masley
2025-01-03 12:15:35 -05:00
committed by GitHub
parent 813270d63a
commit a7ed977ba9
8 changed files with 115 additions and 33 deletions

View File

@ -3,22 +3,27 @@
package cli
import (
"database/sql"
"fmt"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/userpassword"
)
func (*RootCmd) resetPassword() *serpent.Command {
var postgresURL string
var (
postgresURL string
postgresAuth string
)
root := &serpent.Command{
Use: "reset-password <username>",
@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
Handler: func(inv *serpent.Invocation) error {
username := inv.Args[0]
sqlDB, err := sql.Open("postgres", postgresURL)
logger := slog.Make(sloghuman.Sink(inv.Stdout))
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
logger = logger.Leveled(slog.LevelDebug)
}
sqlDriver := "postgres"
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
var err error
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
if err != nil {
return xerrors.Errorf("register aws rds iam auth: %w", err)
}
}
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
if err != nil {
return xerrors.Errorf("dial postgres: %w", err)
}
defer sqlDB.Close()
err = sqlDB.Ping()
if err != nil {
return xerrors.Errorf("ping postgres: %w", err)
}
err = migrations.EnsureClean(sqlDB)
if err != nil {
return xerrors.Errorf("database needs migration: %w", err)
}
db := database.New(sqlDB)
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
Env: "CODER_PG_CONNECTION_URL",
Value: serpent.StringOf(&postgresURL),
},
serpent.Option{
Name: "Postgres Connection Auth",
Description: "Type of auth to use when connecting to postgres.",
Flag: "postgres-connection-auth",
Env: "CODER_PG_CONNECTION_AUTH",
Default: "password",
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
},
}
return root

View File

@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.Database = dbmem.New()
options.Pubsub = pubsub.NewInMemory()
} else {
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
return host == "localhost" || host == "127.0.0.1" || host == "::1"
}
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
// ConnectToPostgres takes in the migration command to run on the database once
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
// Regardless of the passed in migration function, if the database is not fully
// migrated, an error will be returned. This can happen if the database is on a
// future or past migration version.
//
// If no error is returned, the database is fully migrated and up to date.
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
logger.Debug(ctx, "connecting to postgresql")
var err error
var sqlDB *sql.DB
// Try to connect for 30 seconds.
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
@ -2155,10 +2164,17 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
}
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
err = migrations.Up(sqlDB)
if migrate != nil {
err = migrate(sqlDB)
if err != nil {
return nil, xerrors.Errorf("migrate up: %w", err)
}
}
err = migrations.EnsureClean(sqlDB)
if err != nil {
return nil, xerrors.Errorf("migrations in database: %w", err)
}
// The default is 0 but the request will fail with a 500 if the DB
// cannot accept new connections, so we try to limit that here.
// Requests will wait for a new connection instead of a hard error
@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
return inv.SignalNotifyContext(ctx, sig...)
}
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
dbURL, err := escapePostgresURLUserInfo(postgresURL)
if err != nil {
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
}
}
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
if err != nil {
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
}

View File

@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
}
}
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}

View File

@ -38,11 +38,13 @@ import (
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/codersdk"
@ -1828,6 +1830,10 @@ func TestConnectToPostgres(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip("this test does not make sense without postgres")
}
t.Run("Migrate", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)
@ -1836,12 +1842,39 @@ func TestConnectToPostgres(t *testing.T) {
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
})
require.NoError(t, sqlDB.PingContext(ctx))
})
t.Run("NoMigrate", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.NoError(t, err)
defer okDB.Close()
// Set the migration number forward
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
require.NoError(t, err)
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.Error(t, err)
require.ErrorContains(t, err, "database needs migration")
require.NoError(t, okDB.PingContext(ctx))
})
}
func TestServer_InvalidDERP(t *testing.T) {

View File

@ -6,6 +6,9 @@ USAGE:
Directly connect to the database to reset a user's password
OPTIONS:
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
Type of auth to use when connecting to postgres.
--postgres-url string, $CODER_PG_CONNECTION_URL
URL of a PostgreSQL database to connect to.

View File

@ -9,6 +9,7 @@ import (
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/coderd/database/awsiamrds"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
require.NoError(t, err)
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
require.NoError(t, err)
defer func() {
_ = db.Close()

View File

@ -19,3 +19,13 @@ coder reset-password [flags] <username>
| Environment | <code>$CODER_PG_CONNECTION_URL</code> |
URL of a PostgreSQL database to connect to.
### --postgres-connection-auth
| | |
|-------------|----------------------------------------|
| Type | <code>password\|awsiamrds</code> |
| Environment | <code>$CODER_PG_CONNECTION_AUTH</code> |
| Default | <code>password</code> |
Type of auth to use when connecting to postgres.

View File

@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
}
}
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
}
}
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}
@ -219,7 +219,7 @@ Are you sure you want to continue?`
}
}
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
if err != nil {
return xerrors.Errorf("connect to postgres: %w", err)
}