diff --git a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go index 057b6f2f..adcc9400 100644 --- a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go +++ b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go @@ -3275,20 +3275,13 @@ func exerciseListener[TTx any](ctx context.Context, t *testing.T, driverWithPool t.Run("SchemaFromSearchPath", func(t *testing.T) { t.Parallel() - // TODO(brandur): Need to find a way to make this test work. We need to - // inject a `search_path`, but the connection is acquired below inside - // `listener.Connect`, which means we can't do so here without finding a - // way to do some kind of test injection. - t.Skip("needs a way to be test injectable") - - // somehow do: - // SET search_path TO 'public' - var ( driver, _ = driverWithPool(ctx, t) listener = driver.GetListener("") ) + listener.SetAfterConnectExec("SET search_path TO 'public'") + connectListener(ctx, t, listener) require.Equal(t, "public", listener.Schema()) }) diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 5315b911..73b7aad9 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -219,6 +219,7 @@ type Listener interface { Listen(ctx context.Context, topic string) error Ping(ctx context.Context) error Schema() string + SetAfterConnectExec(sql string) // should only ever be used in testing Unlisten(ctx context.Context, topic string) error WaitForNotification(ctx context.Context) (*Notification, error) } diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index 2808f54a..f0cace1e 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -764,11 +764,12 @@ func (t *ExecutorTx) Rollback(ctx context.Context) error { } type Listener struct { - conn *pgx.Conn - dbPool *pgxpool.Pool - prefix string // schema with a dot on the end (very minor optimization) - mu sync.Mutex - schema string + afterConnectExec string // should only ever be used in testing + conn *pgx.Conn + dbPool *pgxpool.Pool + prefix string // schema with a dot on the end (very minor optimization) + mu sync.Mutex + schema string } func (l *Listener) Close(ctx context.Context) error { @@ -805,6 +806,12 @@ func (l *Listener) Connect(ctx context.Context) error { return err } + if l.afterConnectExec != "" { + if _, err := poolConn.Exec(ctx, l.afterConnectExec); err != nil { + return err + } + } + // Use a configured schema if non-empty, otherwise try to select the current // schema based on `search_path`. schema := l.schema @@ -850,6 +857,13 @@ func (l *Listener) Schema() string { return l.schema } +func (l *Listener) SetAfterConnectExec(sql string) { + l.mu.Lock() + defer l.mu.Unlock() + + l.afterConnectExec = sql +} + func (l *Listener) Unlisten(ctx context.Context, topic string) error { l.mu.Lock() defer l.mu.Unlock()