@@ -26,10 +26,17 @@ import (
2626 "github.com/riverqueue/river/rivershared/util/sliceutil"
2727)
2828
29- // The migrate version where the `line` column was added. Meaningful in that the
30- // migrator has to behave a little differently depending on whether it's working
31- // with versions before or after this boundary.
32- const migrateVersionLineColumnAdded = 5
29+ const (
30+ // The migrate version where the `line` column was added. Meaningful in that
31+ // the migrator has to behave a little differently depending on whether it's
32+ // working with versions before or after this boundary.
33+ migrateVersionLineColumnAdded = 5
34+
35+ // The migration version where the `river_migration` table is added. This is
36+ // used for one special case where we don't try to delete a version record
37+ // after downmigrating version 1.
38+ migrateVersionTableAdded = 1
39+ )
3340
3441// Migration is a bundled migration containing a version (e.g. 1, 2, 3), and SQL
3542// for up and down directions.
@@ -308,9 +315,9 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *
308315 exec := m .driver .GetExecutor ()
309316 switch direction {
310317 case DirectionDown :
311- return m .migrateDown (ctx , exec , direction , opts )
318+ return m .migrateDown (ctx , exec , direction , opts , false )
312319 case DirectionUp :
313- return m .migrateUp (ctx , exec , direction , opts )
320+ return m .migrateUp (ctx , exec , direction , opts , false )
314321 }
315322
316323 panic ("invalid direction: " + direction )
@@ -340,9 +347,9 @@ func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *
340347func (m * Migrator [TTx ]) MigrateTx (ctx context.Context , tx TTx , direction Direction , opts * MigrateOpts ) (* MigrateResult , error ) {
341348 switch direction {
342349 case DirectionDown :
343- return m .migrateDown (ctx , m .driver .UnwrapExecutor (tx ), direction , opts )
350+ return m .migrateDown (ctx , m .driver .UnwrapExecutor (tx ), direction , opts , true )
344351 case DirectionUp :
345- return m .migrateUp (ctx , m .driver .UnwrapExecutor (tx ), direction , opts )
352+ return m .migrateUp (ctx , m .driver .UnwrapExecutor (tx ), direction , opts , true )
346353 }
347354
348355 panic ("invalid direction: " + direction )
@@ -377,7 +384,7 @@ func (m *Migrator[TTx]) ValidateTx(ctx context.Context, tx TTx) (*ValidateResult
377384}
378385
379386// migrateDown runs down migrations.
380- func (m * Migrator [TTx ]) migrateDown (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts ) (* MigrateResult , error ) {
387+ func (m * Migrator [TTx ]) migrateDown (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts , inOuterTx bool ) (* MigrateResult , error ) {
381388 existingMigrations , err := m .existingMigrations (ctx , exec )
382389 if err != nil {
383390 return nil , err
@@ -395,7 +402,7 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut
395402 sortedTargetMigrations := maputil .Values (targetMigrations )
396403 slices .SortFunc (sortedTargetMigrations , func (a , b Migration ) int { return b .Version - a .Version }) // reverse order
397404
398- res , err := m .applyMigrations (ctx , exec , direction , opts , sortedTargetMigrations )
405+ res , err := m .applyMigrations (ctx , exec , direction , opts , inOuterTx , sortedTargetMigrations )
399406 if err != nil {
400407 return nil , err
401408 }
@@ -414,36 +421,19 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut
414421 return res , nil
415422 }
416423
417- if ! opts .DryRun && len (res .Versions ) > 0 {
418- versions := sliceutil .Map (res .Versions , migrateVersionToInt )
419-
420- // Version 005 is hard-coded here because that's the version in which
421- // the migration `line` comes in. If migration to a point equal or above
422- // 005, we can remove migrations with a line included, but otherwise we
423- // must omit the `line` column from queries because it doesn't exist.
424- if m .line == riverdriver .MigrationLineMain && slices .Min (versions ) <= migrateVersionLineColumnAdded {
425- if _ , err := exec .MigrationDeleteAssumingMainMany (ctx , & riverdriver.MigrationDeleteAssumingMainManyParams {
426- Versions : versions ,
427- Schema : m .schema ,
428- }); err != nil {
429- return nil , fmt .Errorf ("error inserting migration rows for versions %+v assuming main: %w" , res .Versions , err )
430- }
431- } else {
432- if _ , err := exec .MigrationDeleteByLineAndVersionMany (ctx , & riverdriver.MigrationDeleteByLineAndVersionManyParams {
433- Line : m .line ,
434- Schema : m .schema ,
435- Versions : versions ,
436- }); err != nil {
437- return nil , fmt .Errorf ("error deleting migration rows for versions %+v on line %q: %w" , res .Versions , m .line , err )
438- }
424+ // When operating with an outer transaction, all versions are removed at
425+ // once so we can save a few database operations.
426+ if inOuterTx {
427+ if err := m .versionsDelete (ctx , exec , opts , sliceutil .Map (res .Versions , migrateVersionToInt )... ); err != nil {
428+ return nil , err
439429 }
440430 }
441431
442432 return res , nil
443433}
444434
445435// migrateUp runs up migrations.
446- func (m * Migrator [TTx ]) migrateUp (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts ) (* MigrateResult , error ) {
436+ func (m * Migrator [TTx ]) migrateUp (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts , inOuterTx bool ) (* MigrateResult , error ) {
447437 existingMigrations , err := m .existingMigrations (ctx , exec )
448438 if err != nil {
449439 return nil , err
@@ -457,33 +447,16 @@ func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor
457447 sortedTargetMigrations := maputil .Values (targetMigrations )
458448 slices .SortFunc (sortedTargetMigrations , func (a , b Migration ) int { return a .Version - b .Version })
459449
460- res , err := m .applyMigrations (ctx , exec , direction , opts , sortedTargetMigrations )
450+ res , err := m .applyMigrations (ctx , exec , direction , opts , inOuterTx , sortedTargetMigrations )
461451 if err != nil {
462452 return nil , err
463453 }
464454
465- if (opts == nil || ! opts .DryRun ) && len (res .Versions ) > 0 {
466- versions := sliceutil .Map (res .Versions , migrateVersionToInt )
467-
468- // Version 005 is hard-coded here because that's the version in which
469- // the migration `line` comes in. If migration to a point equal or above
470- // 005, we can insert migrations with a line included, but otherwise we
471- // must omit the `line` column from queries because it doesn't exist.
472- if m .line == riverdriver .MigrationLineMain && slices .Max (versions ) < migrateVersionLineColumnAdded {
473- if _ , err := exec .MigrationInsertManyAssumingMain (ctx , & riverdriver.MigrationInsertManyAssumingMainParams {
474- Schema : m .schema ,
475- Versions : versions ,
476- }); err != nil {
477- return nil , fmt .Errorf ("error inserting migration rows for versions %+v assuming main: %w" , res .Versions , err )
478- }
479- } else {
480- if _ , err := exec .MigrationInsertMany (ctx , & riverdriver.MigrationInsertManyParams {
481- Line : m .line ,
482- Schema : m .schema ,
483- Versions : versions ,
484- }); err != nil {
485- return nil , fmt .Errorf ("error inserting migration rows for versions %+v on line %q: %w" , res .Versions , m .line , err )
486- }
455+ // When operating with an outer transaction, all versions are added at once
456+ // so we can save a few database operations.
457+ if inOuterTx {
458+ if err := m .versionsInsert (ctx , exec , opts , sliceutil .Map (res .Versions , migrateVersionToInt )... ); err != nil {
459+ return nil , err
487460 }
488461 }
489462
@@ -519,7 +492,7 @@ func (m *Migrator[TTx]) validate(ctx context.Context, exec riverdriver.Executor)
519492
520493// Common code shared between the up and down migration directions that walks
521494// through each target migration and applies it, logging appropriately.
522- func (m * Migrator [TTx ]) applyMigrations (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts , sortedTargetMigrations []Migration ) (* MigrateResult , error ) {
495+ func (m * Migrator [TTx ]) applyMigrations (ctx context.Context , exec riverdriver.Executor , direction Direction , opts * MigrateOpts , inOuterTx bool , sortedTargetMigrations []Migration ) (* MigrateResult , error ) {
523496 if opts == nil {
524497 opts = & MigrateOpts {}
525498 }
@@ -606,6 +579,23 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex
606579 return fmt .Errorf ("error applying version %03d [%s]: %w" ,
607580 versionBundle .Version , strings .ToUpper (string (direction )), err )
608581 }
582+
583+ // If operating without outer transaction, add/remove the
584+ // migration version in the same transaction in which we
585+ // executed the migration SQL.
586+ if ! inOuterTx {
587+ switch direction {
588+ case DirectionDown :
589+ if err := m .versionsDelete (ctx , exec , opts , versionBundle .Version ); err != nil {
590+ return err
591+ }
592+ case DirectionUp :
593+ if err := m .versionsInsert (ctx , exec , opts , versionBundle .Version ); err != nil {
594+ return err
595+ }
596+ }
597+ }
598+
609599 return nil
610600 })
611601 if err != nil {
@@ -683,6 +673,70 @@ func (m *Migrator[TTx]) existingMigrations(ctx context.Context, exec riverdriver
683673 return migrations , nil
684674}
685675
676+ func (m * Migrator [TTx ]) versionsDelete (ctx context.Context , exec riverdriver.Executor , opts * MigrateOpts , versions ... int ) error {
677+ if opts .DryRun || len (versions ) < 1 {
678+ return nil
679+ }
680+
681+ // Don't try to remove anything if we're migrating back below version 1,
682+ // where `river_migration` was added.
683+ if len (versions ) == 1 && versions [0 ] <= migrateVersionTableAdded {
684+ return nil
685+ }
686+
687+ // Version 005 is hard-coded here because that's the version in which
688+ // the migration `line` comes in. If migration to a point equal or above
689+ // 005, we can remove migrations with a line included, but otherwise we
690+ // must omit the `line` column from queries because it doesn't exist.
691+ if m .line == riverdriver .MigrationLineMain && slices .Min (versions ) <= migrateVersionLineColumnAdded {
692+ if _ , err := exec .MigrationDeleteAssumingMainMany (ctx , & riverdriver.MigrationDeleteAssumingMainManyParams {
693+ Versions : versions ,
694+ Schema : m .schema ,
695+ }); err != nil {
696+ return fmt .Errorf ("error inserting migration rows for versions %+v assuming main: %w" , versions , err )
697+ }
698+ } else {
699+ if _ , err := exec .MigrationDeleteByLineAndVersionMany (ctx , & riverdriver.MigrationDeleteByLineAndVersionManyParams {
700+ Line : m .line ,
701+ Schema : m .schema ,
702+ Versions : versions ,
703+ }); err != nil {
704+ return fmt .Errorf ("error deleting migration rows for versions %+v on line %q: %w" , versions , m .line , err )
705+ }
706+ }
707+
708+ return nil
709+ }
710+
711+ func (m * Migrator [TTx ]) versionsInsert (ctx context.Context , exec riverdriver.Executor , opts * MigrateOpts , versions ... int ) error {
712+ if opts .DryRun || len (versions ) < 1 {
713+ return nil
714+ }
715+
716+ // Version 005 is hard-coded here because that's the version in which
717+ // the migration `line` comes in. If migration to a point equal or above
718+ // 005, we can insert migrations with a line included, but otherwise we
719+ // must omit the `line` column from queries because it doesn't exist.
720+ if m .line == riverdriver .MigrationLineMain && slices .Max (versions ) < migrateVersionLineColumnAdded {
721+ if _ , err := exec .MigrationInsertManyAssumingMain (ctx , & riverdriver.MigrationInsertManyAssumingMainParams {
722+ Schema : m .schema ,
723+ Versions : versions ,
724+ }); err != nil {
725+ return fmt .Errorf ("error inserting migration rows for versions %+v assuming main: %w" , versions , err )
726+ }
727+ } else {
728+ if _ , err := exec .MigrationInsertMany (ctx , & riverdriver.MigrationInsertManyParams {
729+ Line : m .line ,
730+ Schema : m .schema ,
731+ Versions : versions ,
732+ }); err != nil {
733+ return fmt .Errorf ("error inserting migration rows for versions %+v on line %q: %w" , versions , m .line , err )
734+ }
735+ }
736+
737+ return nil
738+ }
739+
686740// Reads a series of migration bundles from a file system, which practically
687741// speaking will always be the embedded FS read from the contents of the
688742// `migration/<line>/` subdirectory.
0 commit comments