diff --git a/migrate.go b/migrate.go index 31eacbe9..21bd86e8 100644 --- a/migrate.go +++ b/migrate.go @@ -2,6 +2,7 @@ package migrate import ( "bytes" + "context" "database/sql" "errors" "fmt" @@ -429,12 +430,24 @@ type SqlExecutor interface { // // Returns the number of applied migrations. func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { - return ExecMax(db, dialect, m, dir, 0) + return ExecMaxContext(context.Background(), db, dialect, m, dir, 0) } // Returns the number of applied migrations. func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { - return ms.ExecMax(db, dialect, m, dir, 0) + return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0) +} + +// Execute a set of migrations with an input context. +// +// Returns the number of applied migrations. +func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { + return ExecMaxContext(ctx, db, dialect, m, dir, 0) +} + +// Returns the number of applied migrations. +func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { + return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0) } // Execute a set of migrations @@ -446,50 +459,78 @@ func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti return migSet.ExecMax(db, dialect, m, dir, max) } +// Execute a set of migrations with an input context. +// +// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec). +// +// Returns the number of applied migrations. +func ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { + return migSet.ExecMaxContext(ctx, db, dialect, m, dir, max) +} + // Execute a set of migrations // // Will apply at the target `version` of migration. Cannot be a negative value. // // Returns the number of applied migrations. func ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { + return ExecVersionContext(context.Background(), db, dialect, m, dir, version) +} + +// Execute a set of migrations with an input context. +// +// Will apply at the target `version` of migration. Cannot be a negative value. +// +// Returns the number of applied migrations. +func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { if version < 0 { return 0, fmt.Errorf("target version %d should not be negative", version) } - return migSet.ExecVersion(db, dialect, m, dir, version) + return migSet.ExecVersionContext(ctx, db, dialect, m, dir, version) } // Returns the number of applied migrations. func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { + return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max) +} + +// Returns the number of applied migrations, but applies with an input context. +func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max) if err != nil { return 0, err } - return ms.applyMigrations(dir, migrations, dbMap) + return ms.applyMigrations(ctx, dir, migrations, dbMap) } // Returns the number of applied migrations. func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { + return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version) +} + +func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) { migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version) if err != nil { return 0, err } - return ms.applyMigrations(dir, migrations, dbMap) + return ms.applyMigrations(ctx, dir, migrations, dbMap) } // Applies the planned migrations and returns the number of applied migrations. -func (MigrationSet) applyMigrations(dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) { +func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) { applied := 0 for _, migration := range migrations { var executor SqlExecutor var err error if migration.DisableTransaction { - executor = dbMap + executor = dbMap.WithContext(ctx) } else { - executor, err = dbMap.Begin() + e, err := dbMap.Begin() if err != nil { return applied, newTxError(migration, err) } + executor = e.WithContext(ctx) } for _, stmt := range migration.Queries { diff --git a/migrate_test.go b/migrate_test.go index f9d2db2e..71183dc3 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -1,8 +1,10 @@ package migrate import ( + "context" "database/sql" "net/http" + "time" "github.com/go-gorp/gorp/v3" "github.com/gobuffalo/packr/v2" @@ -757,3 +759,39 @@ func (s *SqliteMigrateSuite) TestGetMigrationDbMapWithDisableCreateTable(c *C) { _, err := migSet.getMigrationDbMap(s.Db, "postgres") c.Assert(err, IsNil) } + +func (s *SqliteMigrateSuite) TestContextTimeout(c *C) { + // This statement will run for a long time: 1,000,000 iterations of the fibonacci sequence + fibonacciLoopStmt := `WITH RECURSIVE + fibo (curr, next) + AS + ( SELECT 1,1 + UNION ALL + SELECT next, curr+next FROM fibo + LIMIT 1000000 ) + SELECT group_concat(curr) FROM fibo; + ` + migrations := &MemoryMigrationSource{ + Migrations: []*Migration{ + sqliteMigrations[0], + sqliteMigrations[1], + { + Id: "125", + Up: []string{fibonacciLoopStmt}, + Down: []string{}, // Not important here + }, + { + Id: "125", + Up: []string{"INSERT INTO people (id, first_name) VALUES (1, 'Test')", "SELECT fail"}, + Down: []string{}, // Not important here + }, + }, + } + + // Should never run the insert + ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancelFunc() + n, err := ExecContext(ctx, s.Db, "sqlite3", migrations, Up) + c.Assert(err, Not(IsNil)) + c.Assert(n, Equals, 2) +}