From 726fbd6a7e1bb366d6618cf5e60328314ebf69fc Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 26 Apr 2025 02:25:12 -0700 Subject: [PATCH] Unskip test to check that listener can take prefix from `search_path` The change in #848 added a capability to the listener so that in case a schema is _not_ explicitly provided, it can be detected from `search_path` so that multiple Rivers could share a database as long as their search paths are configured correctly. This was tricky to test though because listeners acquire their own connections, and we don't want to set search path on the database pool because it'd taint a connection on it and we wouldn't be guaranteed to get that connection for the listener anyway. I marked the test case as skipped until I could go and look at it again. Here, add a function to listeners that causes them to invoke an SQL command after a connection is acquired. This should only be used in tests, but lets us check this case. It's a little gnarly to be sure, but the river driver interface is explicitly marked as unstable, and I can't think of any alternatives that aren't worse. --- .../riverdrivertest/riverdrivertest.go | 11 ++------- riverdriver/river_driver_interface.go | 1 + riverdriver/riverpgxv5/river_pgx_v5_driver.go | 24 +++++++++++++++---- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go index 057b6f2f..adcc9400 100644 --- a/internal/riverinternaltest/riverdrivertest/riverdrivertest.go +++ b/internal/riverinternaltest/riverdrivertest/riverdrivertest.go @@ -3275,20 +3275,13 @@ func exerciseListener[TTx any](ctx context.Context, t *testing.T, driverWithPool t.Run("SchemaFromSearchPath", func(t *testing.T) { t.Parallel() - // TODO(brandur): Need to find a way to make this test work. We need to - // inject a `search_path`, but the connection is acquired below inside - // `listener.Connect`, which means we can't do so here without finding a - // way to do some kind of test injection. - t.Skip("needs a way to be test injectable") - - // somehow do: - // SET search_path TO 'public' - var ( driver, _ = driverWithPool(ctx, t) listener = driver.GetListener("") ) + listener.SetAfterConnectExec("SET search_path TO 'public'") + connectListener(ctx, t, listener) require.Equal(t, "public", listener.Schema()) }) diff --git a/riverdriver/river_driver_interface.go b/riverdriver/river_driver_interface.go index 5315b911..73b7aad9 100644 --- a/riverdriver/river_driver_interface.go +++ b/riverdriver/river_driver_interface.go @@ -219,6 +219,7 @@ type Listener interface { Listen(ctx context.Context, topic string) error Ping(ctx context.Context) error Schema() string + SetAfterConnectExec(sql string) // should only ever be used in testing Unlisten(ctx context.Context, topic string) error WaitForNotification(ctx context.Context) (*Notification, error) } diff --git a/riverdriver/riverpgxv5/river_pgx_v5_driver.go b/riverdriver/riverpgxv5/river_pgx_v5_driver.go index 2808f54a..f0cace1e 100644 --- a/riverdriver/riverpgxv5/river_pgx_v5_driver.go +++ b/riverdriver/riverpgxv5/river_pgx_v5_driver.go @@ -764,11 +764,12 @@ func (t *ExecutorTx) Rollback(ctx context.Context) error { } type Listener struct { - conn *pgx.Conn - dbPool *pgxpool.Pool - prefix string // schema with a dot on the end (very minor optimization) - mu sync.Mutex - schema string + afterConnectExec string // should only ever be used in testing + conn *pgx.Conn + dbPool *pgxpool.Pool + prefix string // schema with a dot on the end (very minor optimization) + mu sync.Mutex + schema string } func (l *Listener) Close(ctx context.Context) error { @@ -805,6 +806,12 @@ func (l *Listener) Connect(ctx context.Context) error { return err } + if l.afterConnectExec != "" { + if _, err := poolConn.Exec(ctx, l.afterConnectExec); err != nil { + return err + } + } + // Use a configured schema if non-empty, otherwise try to select the current // schema based on `search_path`. schema := l.schema @@ -850,6 +857,13 @@ func (l *Listener) Schema() string { return l.schema } +func (l *Listener) SetAfterConnectExec(sql string) { + l.mu.Lock() + defer l.mu.Unlock() + + l.afterConnectExec = sql +} + func (l *Listener) Unlisten(ctx context.Context, topic string) error { l.mu.Lock() defer l.mu.Unlock()