mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
chore: prevent db migrations from running on all cli commands (#15980)
This commit is contained in:
@ -3,22 +3,27 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/userpassword"
|
||||
)
|
||||
|
||||
func (*RootCmd) resetPassword() *serpent.Command {
|
||||
var postgresURL string
|
||||
var (
|
||||
postgresURL string
|
||||
postgresAuth string
|
||||
)
|
||||
|
||||
root := &serpent.Command{
|
||||
Use: "reset-password <username>",
|
||||
@ -27,20 +32,26 @@ func (*RootCmd) resetPassword() *serpent.Command {
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
username := inv.Args[0]
|
||||
|
||||
sqlDB, err := sql.Open("postgres", postgresURL)
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stdout))
|
||||
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
sqlDriver := "postgres"
|
||||
if codersdk.PostgresAuth(postgresAuth) == codersdk.PostgresAuthAWSIAMRDS {
|
||||
var err error
|
||||
sqlDriver, err = awsiamrds.Register(inv.Context(), sqlDriver)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("register aws rds iam auth: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := ConnectToPostgres(inv.Context(), logger, sqlDriver, postgresURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial postgres: %w", err)
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
err = sqlDB.Ping()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
|
||||
err = migrations.EnsureClean(sqlDB)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("database needs migration: %w", err)
|
||||
}
|
||||
db := database.New(sqlDB)
|
||||
|
||||
user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{
|
||||
@ -97,6 +108,14 @@ func (*RootCmd) resetPassword() *serpent.Command {
|
||||
Env: "CODER_PG_CONNECTION_URL",
|
||||
Value: serpent.StringOf(&postgresURL),
|
||||
},
|
||||
serpent.Option{
|
||||
Name: "Postgres Connection Auth",
|
||||
Description: "Type of auth to use when connecting to postgres.",
|
||||
Flag: "postgres-connection-auth",
|
||||
Env: "CODER_PG_CONNECTION_AUTH",
|
||||
Default: "password",
|
||||
Value: serpent.EnumOf(&postgresAuth, codersdk.PostgresAuthDrivers...),
|
||||
},
|
||||
}
|
||||
|
||||
return root
|
||||
|
@ -697,7 +697,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
options.Database = dbmem.New()
|
||||
options.Pubsub = pubsub.NewInMemory()
|
||||
} else {
|
||||
sqlDB, dbURL, err := getPostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
|
||||
sqlDB, dbURL, err := getAndMigratePostgresDB(ctx, logger, vals.PostgresURL.String(), codersdk.PostgresAuth(vals.PostgresAuth), sqlDriver)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
@ -2090,9 +2090,18 @@ func IsLocalhost(host string) bool {
|
||||
return host == "localhost" || host == "127.0.0.1" || host == "::1"
|
||||
}
|
||||
|
||||
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
|
||||
// ConnectToPostgres takes in the migration command to run on the database once
|
||||
// it connects. To avoid running migrations, pass in `nil` or a no-op function.
|
||||
// Regardless of the passed in migration function, if the database is not fully
|
||||
// migrated, an error will be returned. This can happen if the database is on a
|
||||
// future or past migration version.
|
||||
//
|
||||
// If no error is returned, the database is fully migrated and up to date.
|
||||
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string, migrate func(db *sql.DB) error) (*sql.DB, error) {
|
||||
logger.Debug(ctx, "connecting to postgresql")
|
||||
|
||||
var err error
|
||||
var sqlDB *sql.DB
|
||||
// Try to connect for 30 seconds.
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
@ -2155,10 +2164,17 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
|
||||
}
|
||||
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
|
||||
|
||||
err = migrations.Up(sqlDB)
|
||||
if migrate != nil {
|
||||
err = migrate(sqlDB)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("migrate up: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = migrations.EnsureClean(sqlDB)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("migrations in database: %w", err)
|
||||
}
|
||||
// The default is 0 but the request will fail with a 500 if the DB
|
||||
// cannot accept new connections, so we try to limit that here.
|
||||
// Requests will wait for a new connection instead of a hard error
|
||||
@ -2561,7 +2577,7 @@ func signalNotifyContext(ctx context.Context, inv *serpent.Invocation, sig ...os
|
||||
return inv.SignalNotifyContext(ctx, sig...)
|
||||
}
|
||||
|
||||
func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
|
||||
func getAndMigratePostgresDB(ctx context.Context, logger slog.Logger, postgresURL string, auth codersdk.PostgresAuth, sqlDriver string) (*sql.DB, string, error) {
|
||||
dbURL, err := escapePostgresURLUserInfo(postgresURL)
|
||||
if err != nil {
|
||||
return nil, "", xerrors.Errorf("escaping postgres URL: %w", err)
|
||||
@ -2574,7 +2590,7 @@ func getPostgresDB(ctx context.Context, logger slog.Logger, postgresURL string,
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL)
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, dbURL, migrations.Up)
|
||||
if err != nil {
|
||||
return nil, "", xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL)
|
||||
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, newUserDBURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
|
@ -38,11 +38,13 @@ import (
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/cli"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/cli/config"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@ -1828,6 +1830,10 @@ func TestConnectToPostgres(t *testing.T) {
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test does not make sense without postgres")
|
||||
}
|
||||
|
||||
t.Run("Migrate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
@ -1836,12 +1842,39 @@ func TestConnectToPostgres(t *testing.T) {
|
||||
dbURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL)
|
||||
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sqlDB.Close()
|
||||
})
|
||||
require.NoError(t, sqlDB.PingContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("NoMigrate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
dbURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
|
||||
require.NoError(t, err)
|
||||
defer okDB.Close()
|
||||
|
||||
// Set the migration number forward
|
||||
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "database needs migration")
|
||||
|
||||
require.NoError(t, okDB.PingContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_InvalidDERP(t *testing.T) {
|
||||
|
@ -6,6 +6,9 @@ USAGE:
|
||||
Directly connect to the database to reset a user's password
|
||||
|
||||
OPTIONS:
|
||||
--postgres-connection-auth password|awsiamrds, $CODER_PG_CONNECTION_AUTH (default: password)
|
||||
Type of auth to use when connecting to postgres.
|
||||
|
||||
--postgres-url string, $CODER_PG_CONNECTION_URL
|
||||
URL of a PostgreSQL database to connect to.
|
||||
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli"
|
||||
"github.com/coder/coder/v2/coderd/database/awsiamrds"
|
||||
"github.com/coder/coder/v2/coderd/database/migrations"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@ -32,7 +33,7 @@ func TestDriver(t *testing.T) {
|
||||
sqlDriver, err := awsiamrds.Register(ctx, "postgres")
|
||||
require.NoError(t, err)
|
||||
|
||||
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url)
|
||||
db, err := cli.ConnectToPostgres(ctx, testutil.Logger(t), sqlDriver, url, migrations.Up)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
|
10
docs/reference/cli/reset-password.md
generated
10
docs/reference/cli/reset-password.md
generated
@ -19,3 +19,13 @@ coder reset-password [flags] <username>
|
||||
| Environment | <code>$CODER_PG_CONNECTION_URL</code> |
|
||||
|
||||
URL of a PostgreSQL database to connect to.
|
||||
|
||||
### --postgres-connection-auth
|
||||
|
||||
| | |
|
||||
|-------------|----------------------------------------|
|
||||
| Type | <code>password\|awsiamrds</code> |
|
||||
| Environment | <code>$CODER_PG_CONNECTION_AUTH</code> |
|
||||
| Default | <code>password</code> |
|
||||
|
||||
Type of auth to use when connecting to postgres.
|
||||
|
@ -98,7 +98,7 @@ func (*RootCmd) dbcryptRotateCmd() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
@ -163,7 +163,7 @@ func (*RootCmd) dbcryptDecryptCmd() *serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
@ -219,7 +219,7 @@ Are you sure you want to continue?`
|
||||
}
|
||||
}
|
||||
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL)
|
||||
sqlDB, err := cli.ConnectToPostgres(inv.Context(), logger, sqlDriver, flags.PostgresURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to postgres: %w", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user