Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package migrate

import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -429,12 +430,24 @@ type SqlExecutor interface {
//
// Returns the number of applied migrations.
func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ExecMax(db, dialect, m, dir, 0)
return ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
}

// Returns the number of applied migrations.
func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMax(db, dialect, m, dir, 0)
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
}

// Execute a set of migrations with an input context.
//
// Returns the number of applied migrations.
func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ExecMaxContext(ctx, db, dialect, m, dir, 0)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0)
}

// Execute a set of migrations
Expand All @@ -446,50 +459,78 @@ func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti
return migSet.ExecMax(db, dialect, m, dir, max)
}

// Execute a set of migrations with an input context.
//
// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec).
//
// Returns the number of applied migrations.
func ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
return migSet.ExecMaxContext(ctx, db, dialect, m, dir, max)
}

// Execute a set of migrations
//
// Will apply at the target `version` of migration. Cannot be a negative value.
//
// Returns the number of applied migrations.
func ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
return ExecVersionContext(context.Background(), db, dialect, m, dir, version)
}

// Execute a set of migrations with an input context.
//
// Will apply at the target `version` of migration. Cannot be a negative value.
//
// Returns the number of applied migrations.
func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
if version < 0 {
return 0, fmt.Errorf("target version %d should not be negative", version)
}
return migSet.ExecVersion(db, dialect, m, dir, version)
return migSet.ExecVersionContext(ctx, db, dialect, m, dir, version)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max)
}

// Returns the number of applied migrations, but applies with an input context.
func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max)
if err != nil {
return 0, err
}
return ms.applyMigrations(dir, migrations, dbMap)
return ms.applyMigrations(ctx, dir, migrations, dbMap)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version)
}

func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version)
if err != nil {
return 0, err
}
return ms.applyMigrations(dir, migrations, dbMap)
return ms.applyMigrations(ctx, dir, migrations, dbMap)
}

// Applies the planned migrations and returns the number of applied migrations.
func (MigrationSet) applyMigrations(dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
applied := 0
for _, migration := range migrations {
var executor SqlExecutor
var err error

if migration.DisableTransaction {
executor = dbMap
executor = dbMap.WithContext(ctx)
} else {
executor, err = dbMap.Begin()
e, err := dbMap.Begin()
if err != nil {
return applied, newTxError(migration, err)
}
executor = e.WithContext(ctx)
}

for _, stmt := range migration.Queries {
Expand Down
38 changes: 38 additions & 0 deletions migrate_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package migrate

import (
"context"
"database/sql"
"net/http"
"time"

"github.com/go-gorp/gorp/v3"
"github.com/gobuffalo/packr/v2"
Expand Down Expand Up @@ -757,3 +759,39 @@ func (s *SqliteMigrateSuite) TestGetMigrationDbMapWithDisableCreateTable(c *C) {
_, err := migSet.getMigrationDbMap(s.Db, "postgres")
c.Assert(err, IsNil)
}

func (s *SqliteMigrateSuite) TestContextTimeout(c *C) {
// This statement will run for a long time: 1,000,000 iterations of the fibonacci sequence
fibonacciLoopStmt := `WITH RECURSIVE
fibo (curr, next)
AS
( SELECT 1,1
UNION ALL
SELECT next, curr+next FROM fibo
LIMIT 1000000 )
SELECT group_concat(curr) FROM fibo;
`
migrations := &MemoryMigrationSource{
Migrations: []*Migration{
sqliteMigrations[0],
sqliteMigrations[1],
{
Id: "125",
Up: []string{fibonacciLoopStmt},
Down: []string{}, // Not important here
},
{
Id: "125",
Up: []string{"INSERT INTO people (id, first_name) VALUES (1, 'Test')", "SELECT fail"},
Down: []string{}, // Not important here
},
},
}

// Should never run the insert
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancelFunc()
n, err := ExecContext(ctx, s.Db, "sqlite3", migrations, Up)
c.Assert(err, Not(IsNil))
c.Assert(n, Equals, 2)
}