diff --git a/CHANGELOG.md b/CHANGELOG.md index e48b9be3..1671f8a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Basic stuck detection after a job's exceeded its timeout and still not returned after the executor's initiated context cancellation and waited a short margin for the cancellation to take effect. [PR #1097](https://github.com/riverqueue/river/pull/1097). - Added `Client.JobUpdate` which can be used to persist job output partway through a running work function instead of having to wait until the job is completed. [PR #1098](https://github.com/riverqueue/river/pull/1098). - Add a little more error flavor for when encountering a deadline exceeded error on leadership election suggesting that the user may want to try increasing their database pool size. [PR #1101](https://github.com/riverqueue/river/pull/1101). +- When migrating without an outer transaction, insert/delete version rows immediately after executing migration SQL so that in case a later migration fails, the migrator knows where to restart from. [PR #1106](https://github.com/riverqueue/river/pull/1106). ## [0.29.0-rc.1] - 2025-12-04 diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index 1c28ed7e..99068256 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -26,10 +26,17 @@ import ( "github.com/riverqueue/river/rivershared/util/sliceutil" ) -// The migrate version where the `line` column was added. Meaningful in that the -// migrator has to behave a little differently depending on whether it's working -// with versions before or after this boundary. -const migrateVersionLineColumnAdded = 5 +const ( + // The migrate version where the `line` column was added. Meaningful in that + // the migrator has to behave a little differently depending on whether it's + // working with versions before or after this boundary. + migrateVersionLineColumnAdded = 5 + + // The migration version where the `river_migration` table is added. This is + // used for one special case where we don't try to delete a version record + // after downmigrating version 1. + migrateVersionTableAdded = 1 +) // Migration is a bundled migration containing a version (e.g. 1, 2, 3), and SQL // for up and down directions. @@ -308,9 +315,9 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts * exec := m.driver.GetExecutor() switch direction { case DirectionDown: - return m.migrateDown(ctx, exec, direction, opts) + return m.migrateDown(ctx, exec, direction, opts, false) case DirectionUp: - return m.migrateUp(ctx, exec, direction, opts) + return m.migrateUp(ctx, exec, direction, opts, false) } panic("invalid direction: " + direction) @@ -340,9 +347,9 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts * func (m *Migrator[TTx]) MigrateTx(ctx context.Context, tx TTx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDown(ctx, m.driver.UnwrapExecutor(tx), direction, opts) + return m.migrateDown(ctx, m.driver.UnwrapExecutor(tx), direction, opts, true) case DirectionUp: - return m.migrateUp(ctx, m.driver.UnwrapExecutor(tx), direction, opts) + return m.migrateUp(ctx, m.driver.UnwrapExecutor(tx), direction, opts, true) } panic("invalid direction: " + direction) @@ -377,7 +384,7 @@ func (m *Migrator[TTx]) ValidateTx(ctx context.Context, tx TTx) (*ValidateResult } // migrateDown runs down migrations. -func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { +func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, inOuterTx bool) (*MigrateResult, error) { existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err @@ -395,7 +402,7 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b Migration) int { return b.Version - a.Version }) // reverse order - res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, inOuterTx, sortedTargetMigrations) if err != nil { return nil, err } @@ -414,28 +421,11 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut return res, nil } - if !opts.DryRun && len(res.Versions) > 0 { - versions := sliceutil.Map(res.Versions, migrateVersionToInt) - - // Version 005 is hard-coded here because that's the version in which - // the migration `line` comes in. If migration to a point equal or above - // 005, we can remove migrations with a line included, but otherwise we - // must omit the `line` column from queries because it doesn't exist. - if m.line == riverdriver.MigrationLineMain && slices.Min(versions) <= migrateVersionLineColumnAdded { - if _, err := exec.MigrationDeleteAssumingMainMany(ctx, &riverdriver.MigrationDeleteAssumingMainManyParams{ - Versions: versions, - Schema: m.schema, - }); err != nil { - return nil, fmt.Errorf("error inserting migration rows for versions %+v assuming main: %w", res.Versions, err) - } - } else { - if _, err := exec.MigrationDeleteByLineAndVersionMany(ctx, &riverdriver.MigrationDeleteByLineAndVersionManyParams{ - Line: m.line, - Schema: m.schema, - Versions: versions, - }); err != nil { - return nil, fmt.Errorf("error deleting migration rows for versions %+v on line %q: %w", res.Versions, m.line, err) - } + // When operating with an outer transaction, all versions are removed at + // once so we can save a few database operations. + if inOuterTx { + if err := m.versionsDelete(ctx, exec, opts, sliceutil.Map(res.Versions, migrateVersionToInt)...); err != nil { + return nil, err } } @@ -443,7 +433,7 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut } // migrateUp runs up migrations. -func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { +func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, inOuterTx bool) (*MigrateResult, error) { existingMigrations, err := m.existingMigrations(ctx, exec) if err != nil { return nil, err @@ -457,33 +447,16 @@ func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor sortedTargetMigrations := maputil.Values(targetMigrations) slices.SortFunc(sortedTargetMigrations, func(a, b Migration) int { return a.Version - b.Version }) - res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) + res, err := m.applyMigrations(ctx, exec, direction, opts, inOuterTx, sortedTargetMigrations) if err != nil { return nil, err } - if (opts == nil || !opts.DryRun) && len(res.Versions) > 0 { - versions := sliceutil.Map(res.Versions, migrateVersionToInt) - - // Version 005 is hard-coded here because that's the version in which - // the migration `line` comes in. If migration to a point equal or above - // 005, we can insert migrations with a line included, but otherwise we - // must omit the `line` column from queries because it doesn't exist. - if m.line == riverdriver.MigrationLineMain && slices.Max(versions) < migrateVersionLineColumnAdded { - if _, err := exec.MigrationInsertManyAssumingMain(ctx, &riverdriver.MigrationInsertManyAssumingMainParams{ - Schema: m.schema, - Versions: versions, - }); err != nil { - return nil, fmt.Errorf("error inserting migration rows for versions %+v assuming main: %w", res.Versions, err) - } - } else { - if _, err := exec.MigrationInsertMany(ctx, &riverdriver.MigrationInsertManyParams{ - Line: m.line, - Schema: m.schema, - Versions: versions, - }); err != nil { - return nil, fmt.Errorf("error inserting migration rows for versions %+v on line %q: %w", res.Versions, m.line, err) - } + // When operating with an outer transaction, all versions are added at once + // so we can save a few database operations. + if inOuterTx { + if err := m.versionsInsert(ctx, exec, opts, sliceutil.Map(res.Versions, migrateVersionToInt)...); err != nil { + return nil, err } } @@ -519,7 +492,7 @@ func (m *Migrator[TTx]) validate(ctx context.Context, exec riverdriver.Executor) // Common code shared between the up and down migration directions that walks // through each target migration and applies it, logging appropriately. -func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, sortedTargetMigrations []Migration) (*MigrateResult, error) { +func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, inOuterTx bool, sortedTargetMigrations []Migration) (*MigrateResult, error) { if opts == nil { opts = &MigrateOpts{} } @@ -606,6 +579,23 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex return fmt.Errorf("error applying version %03d [%s]: %w", versionBundle.Version, strings.ToUpper(string(direction)), err) } + + // If operating without outer transaction, add/remove the + // migration version in the same transaction in which we + // executed the migration SQL. + if !inOuterTx { + switch direction { + case DirectionDown: + if err := m.versionsDelete(ctx, exec, opts, versionBundle.Version); err != nil { + return err + } + case DirectionUp: + if err := m.versionsInsert(ctx, exec, opts, versionBundle.Version); err != nil { + return err + } + } + } + return nil }) if err != nil { @@ -683,6 +673,70 @@ func (m *Migrator[TTx]) existingMigrations(ctx context.Context, exec riverdriver return migrations, nil } +func (m *Migrator[TTx]) versionsDelete(ctx context.Context, exec riverdriver.Executor, opts *MigrateOpts, versions ...int) error { + if opts.DryRun || len(versions) < 1 { + return nil + } + + // Don't try to remove anything if we're migrating back below version 1, + // where `river_migration` was added. + if len(versions) == 1 && versions[0] <= migrateVersionTableAdded { + return nil + } + + // Version 005 is hard-coded here because that's the version in which + // the migration `line` comes in. If migration to a point equal or above + // 005, we can remove migrations with a line included, but otherwise we + // must omit the `line` column from queries because it doesn't exist. + if m.line == riverdriver.MigrationLineMain && slices.Min(versions) <= migrateVersionLineColumnAdded { + if _, err := exec.MigrationDeleteAssumingMainMany(ctx, &riverdriver.MigrationDeleteAssumingMainManyParams{ + Versions: versions, + Schema: m.schema, + }); err != nil { + return fmt.Errorf("error inserting migration rows for versions %+v assuming main: %w", versions, err) + } + } else { + if _, err := exec.MigrationDeleteByLineAndVersionMany(ctx, &riverdriver.MigrationDeleteByLineAndVersionManyParams{ + Line: m.line, + Schema: m.schema, + Versions: versions, + }); err != nil { + return fmt.Errorf("error deleting migration rows for versions %+v on line %q: %w", versions, m.line, err) + } + } + + return nil +} + +func (m *Migrator[TTx]) versionsInsert(ctx context.Context, exec riverdriver.Executor, opts *MigrateOpts, versions ...int) error { + if opts.DryRun || len(versions) < 1 { + return nil + } + + // Version 005 is hard-coded here because that's the version in which + // the migration `line` comes in. If migration to a point equal or above + // 005, we can insert migrations with a line included, but otherwise we + // must omit the `line` column from queries because it doesn't exist. + if m.line == riverdriver.MigrationLineMain && slices.Max(versions) < migrateVersionLineColumnAdded { + if _, err := exec.MigrationInsertManyAssumingMain(ctx, &riverdriver.MigrationInsertManyAssumingMainParams{ + Schema: m.schema, + Versions: versions, + }); err != nil { + return fmt.Errorf("error inserting migration rows for versions %+v assuming main: %w", versions, err) + } + } else { + if _, err := exec.MigrationInsertMany(ctx, &riverdriver.MigrationInsertManyParams{ + Line: m.line, + Schema: m.schema, + Versions: versions, + }); err != nil { + return fmt.Errorf("error inserting migration rows for versions %+v on line %q: %w", versions, m.line, err) + } + } + + return nil +} + // Reads a series of migration bundles from a file system, which practically // speaking will always be the embedded FS read from the contents of the // `migration//` subdirectory. diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go index 6ce56dd1..ef44f29c 100644 --- a/rivermigrate/river_migrate_test.go +++ b/rivermigrate/river_migrate_test.go @@ -392,6 +392,43 @@ func TestMigrator(t *testing.T) { sliceutil.Map(migrations, driverMigrationToInt)) }) + // Can't use riverdbtest inthis package due to a circular dependency problem. + testTx := func(t *testing.T, driver *driverWithAlternateLine) pgx.Tx { + t.Helper() + + execTx, err := driver.GetExecutor().Begin(ctx) + require.NoError(t, err) + + t.Cleanup(func() { require.NoError(t, execTx.Rollback(ctx)) }) + + return driver.UnwrapTx(execTx) + } + + t.Run("MigrateDownTx", func(t *testing.T) { + t.Parallel() + + // Some transactional incompatibilities were introduced into the + // migration lines so we can no longer exercise the *Tx functions all + // the way up and down right now. Only do a couple steps to give them a + // little exercise and in such a away that they're functional. + // still work + const maxSteps = 2 + + migrator, bundle := setup(t) + + tx := testTx(t, bundle.driver) + + _, err := migrator.MigrateTx(ctx, tx, DirectionUp, &MigrateOpts{ + MaxSteps: maxSteps, + }) + require.NoError(t, err) + + _, err = migrator.MigrateTx(ctx, tx, DirectionDown, &MigrateOpts{ + MaxSteps: maxSteps, + }) + require.NoError(t, err) + }) + t.Run("GetVersion", func(t *testing.T) { t.Parallel() @@ -581,6 +618,26 @@ func TestMigrator(t *testing.T) { sliceutil.Map(migrations, driverMigrationToInt)) }) + t.Run("MigrateUpTx", func(t *testing.T) { + t.Parallel() + + // Some transactional incompatibilities were introduced into the + // migration lines so we can no longer exercise the *Tx functions all + // the way up and down right now. Only do a couple steps to give them a + // little exercise and in such a away that they're functional. + // still work + const maxSteps = 2 + + migrator, bundle := setup(t) + + tx := testTx(t, bundle.driver) + + _, err := migrator.MigrateTx(ctx, tx, DirectionUp, &MigrateOpts{ + MaxSteps: maxSteps, + }) + require.NoError(t, err) + }) + t.Run("ValidateSuccess", func(t *testing.T) { t.Parallel()