diff --git a/riverdbtest/riverdbtest.go b/riverdbtest/riverdbtest.go index 8b793225..72f026a1 100644 --- a/riverdbtest/riverdbtest.go +++ b/riverdbtest/riverdbtest.go @@ -428,22 +428,36 @@ func packageFromFunc(funcName string) string { return packageName } -// TestTx starts a test transaction that's rolled back automatically as the test -// case is cleaning itself up. +// TestTxPgx starts a test transaction that's rolled back automatically as the +// test case is cleaning itself up. // // This variant starts a transaction for the standard pgx/v5 driver most // commonly used throughout most of River. func TestTxPgx(ctx context.Context, tb testing.TB) pgx.Tx { tb.Helper() + tx, _ := TestTxPgxDriver(ctx, tb, riverpgxv5.New(riversharedtest.DBPool(ctx, tb)), nil) + return tx +} - tx, schema := TestTx(ctx, tb, riverpgxv5.New(riversharedtest.DBPool(ctx, tb)), &TestTxOpts{ - IsTestTxHelper: true, - }) +// TestTxPgxDriver starts a test transaction that's rolled back automatically as +// the test case is cleaning itself up. Unlike TestTxPgx, this variant takes a +// driver and options for greater flexibility, including allowing for Pro +// drivers, while still sharing common setup like schema search path. +func TestTxPgxDriver(ctx context.Context, tb testing.TB, driver riverdriver.Driver[pgx.Tx], opts *TestTxOpts) (pgx.Tx, string) { + tb.Helper() + + var optsCopy TestTxOpts + if opts != nil { + optsCopy = *opts + } + optsCopy.IsTestTxHelper = true + + tx, schema := TestTx(ctx, tb, driver, &optsCopy) _, err := tx.Exec(ctx, "SET search_path TO '"+schema+"'") require.NoError(tb, err) - return tx + return tx, schema } // TestTxOpts are options for TestTx. Most of the time these can be left as nil. @@ -465,7 +479,7 @@ type TestTxOpts struct { IsTestTxHelper bool // ProcurePool returns a database pool that will be set to the input driver - // using Driver.PoolSet. This is an optional parameter and should usuall be + // using Driver.PoolSet. This is an optional parameter and should usually be // left unset. It exists for use with SQLite to generate a database pool for // use in testing after a test schema is available because unlike other // databases, test schemas in SQLite (which are actually test databases) are @@ -620,7 +634,7 @@ func testTxSchemaForDatabaseAndMigrationLines[TTx any](ctx context.Context, tb t // for purposes of schema naming. skipExtraFrames := 2 if opts.IsTestTxHelper { - skipExtraFrames++ + skipExtraFrames += 2 } schema = TestSchema(ctx, tb, driver, &TestSchemaOpts{ diff --git a/riverdriver/riverdrivertest/driver_test.go b/riverdriver/riverdrivertest/driver_test.go index ae2a792f..7a92cb48 100644 --- a/riverdriver/riverdrivertest/driver_test.go +++ b/riverdriver/riverdrivertest/driver_test.go @@ -51,7 +51,7 @@ func TestDriverRiverDatabaseSQLLibPQ(t *testing.T) { tx, schema := riverdbtest.TestTx(ctx, t, driver, nil) - // The same thing as the built-in riversharedtest.TestTx does. + // The same thing as the built-in riverdbtest.TestTxPgx does. _, err := tx.ExecContext(ctx, "SET search_path TO '"+schema+"'") require.NoError(t, err) @@ -81,7 +81,7 @@ func TestDriverRiverDatabaseSQLPgx(t *testing.T) { tx, schema := riverdbtest.TestTx(ctx, t, driver, nil) - // The same thing as the built-in riversharedtest.TestTx does. + // The same thing as the built-in riverdbtest.TestTxPgx does. _, err := tx.ExecContext(ctx, "SET search_path TO '"+schema+"'") require.NoError(t, err) @@ -239,7 +239,7 @@ func BenchmarkDriverRiverDatabaseSQLLibPQ(b *testing.B) { tx, schema := riverdbtest.TestTx(ctx, b, driver, nil) - // The same thing as the built-in riversharedtest.TestTx does. + // The same thing as the built-in riverdbtest.TestTxPgx does. _, err := tx.ExecContext(ctx, "SET search_path TO '"+schema+"'") require.NoError(b, err) @@ -268,7 +268,7 @@ func BenchmarkDriverRiverDatabaseSQLPgx(b *testing.B) { tx, schema := riverdbtest.TestTx(ctx, b, driver, nil) - // The same thing as the built-in riversharedtest.TestTx does. + // The same thing as the built-in riverdbtest.TestTxPgx does. _, err := tx.ExecContext(ctx, "SET search_path TO '"+schema+"'") require.NoError(b, err)