From b15a5d639cf48ad4018082e5472ca7e40a6cbd40 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 26 Feb 2026 13:39:01 -0500 Subject: [PATCH 1/4] Add scope method --- pyrenew/metaclass.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index be594b78..679628f7 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -6,6 +6,7 @@ import jax.random as jr import numpy as np +import numpyro from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample @@ -101,6 +102,23 @@ def validate(**kwargs: object) -> None: """ pass + def scope(self) -> numpyro.handlers.scope: + """ + Standardized [`numpyro.handlers.scope`][] context for + PyRenew [`RandomVariable`][]s. This can be used to + naming of any internal sampling sites within the + [`RandomVariable`][]'s [`self.sample()`][] method. + + The scope prefix is always the [`name`][self.name] of the `RandomVariable` + and the divider is always `_`. + + Returns + ------- + numpyro.handlers.scope + A properly configured scope handler. + """ + return numpyro.handlers.scope(prefix=self.name, divider="_") + def __call__(self, **kwargs: object) -> tuple: """ Alias for `sample`. From fcd09269f0bda8f6abe036a678cb18f7eb3a00d7 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 9 Mar 2026 14:22:48 -0400 Subject: [PATCH 2/4] Add scope to process module --- pyrenew/process/ar.py | 13 ++++----- pyrenew/process/differencedprocess.py | 21 ++++++++------- pyrenew/process/iidrandomsequence.py | 38 +++++++++++++-------------- pyrenew/process/periodiceffect.py | 11 ++++---- pyrenew/process/randomwalk.py | 4 +-- 5 files changed, 45 insertions(+), 42 deletions(-) diff --git a/pyrenew/process/ar.py b/pyrenew/process/ar.py index 4cde78f2..1ec47db0 100644 --- a/pyrenew/process/ar.py +++ b/pyrenew/process/ar.py @@ -159,12 +159,13 @@ def transition( recent_vals: ArrayLike, _: ArrayLike ) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08 with numpyro.handlers.reparam(config={noise_name: LocScaleReparam(0)}): - next_noise = numpyro.sample( - noise_name, - numpyro.distributions.Normal( - loc=jnp.zeros(noise_shape), scale=noise_sd - ), - ) + with self.scope(): + next_noise = numpyro.sample( + noise_name, + numpyro.distributions.Normal( + loc=jnp.zeros(noise_shape), scale=noise_sd + ), + ) dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals) new_term = dot_prod + next_noise diff --git a/pyrenew/process/differencedprocess.py b/pyrenew/process/differencedprocess.py index ba4f15bf..2dd652e2 100644 --- a/pyrenew/process/differencedprocess.py +++ b/pyrenew/process/differencedprocess.py @@ -172,15 +172,16 @@ def sample( ) n_diffs = n - self.differencing_order - if n_diffs > 0: - diff_samp = self.fundamental_process.sample( - *args, - n=n_diffs, - init_vals=fundamental_process_init_vals, - **kwargs, - ) - diffs = diff_samp - else: - diffs = jnp.array([]) + with self.scope(): + if n_diffs > 0: + diff_samp = self.fundamental_process.sample( + *args, + n=n_diffs, + init_vals=fundamental_process_init_vals, + **kwargs, + ) + diffs = diff_samp + else: + diffs = jnp.array([]) integrated_ts = integrate_discrete(init_vals, diffs)[:n] return integrated_ts diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index f80689dc..ab156513 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -37,9 +37,9 @@ def __init__( Returns ------- None - """ - super().__init__(name=name, **kwargs) + """ self.element_rv = element_rv + super().__init__(name=name, **kwargs) def sample( self, n: int, *args: object, vectorize: bool = False, **kwargs: object @@ -76,22 +76,22 @@ def sample( `n` samples from self.distribution`. """ - if vectorize and hasattr(self.element_rv, "expand_by"): - result = self.element_rv.expand_by((n,)).sample(*args, **kwargs) - else: - - def transition(_carry: None, _x: None) -> tuple[None, ArrayLike]: - # numpydoc ignore=GL08 - el = self.element_rv.sample(*args, **kwargs) - return None, el - - _, result = scan( - transition, - xs=None, - init=None, - length=n, - ) - + with self.scope(): + if vectorize and hasattr(self.element_rv, "expand_by"): + result = self.element_rv.expand_by((n,)).sample(*args, **kwargs) + else: + + def transition(_carry: None, _x: None) -> tuple[None, ArrayLike]: + # numpydoc ignore=GL08 + el = self.element_rv.sample(*args, **kwargs) + return None, el + + _, result = scan( + transition, + xs=None, + init=None, + length=n, + ) return result @staticmethod @@ -138,6 +138,6 @@ def __init__( super().__init__( name=name, element_rv=DistributionalVariable( - name=f"{name}_element", distribution=dist.Normal(0, 1) + name="element", distribution=dist.Normal(0, 1) ).expand_by(element_shape), ) diff --git a/pyrenew/process/periodiceffect.py b/pyrenew/process/periodiceffect.py index 8691d321..07c5d443 100644 --- a/pyrenew/process/periodiceffect.py +++ b/pyrenew/process/periodiceffect.py @@ -78,11 +78,12 @@ def sample(self, duration: int, **kwargs: object) -> ArrayLike: ArrayLike """ - return au.tile_until_n( - data=self.quantity_to_broadcast.sample(**kwargs), - n_timepoints=duration, - offset=self.offset, - ) + with self.scope(): + return au.tile_until_n( + data=self.quantity_to_broadcast.sample(**kwargs), + n_timepoints=duration, + offset=self.offset, + ) class DayOfWeekEffect(PeriodicEffect): diff --git a/pyrenew/process/randomwalk.py b/pyrenew/process/randomwalk.py index 97c6374e..cff83aca 100644 --- a/pyrenew/process/randomwalk.py +++ b/pyrenew/process/randomwalk.py @@ -45,7 +45,7 @@ class constructor. super().__init__( name=name, fundamental_process=IIDRandomSequence( - name=f"{name}_iid_seq", element_rv=step_rv + name="iid_seq", element_rv=step_rv ), differencing_order=1, **kwargs, @@ -85,7 +85,7 @@ def __init__( super().__init__( name=name, step_rv=DistributionalVariable( - name=f"{name}_step", distribution=dist.Normal(0.0, 1.0) + name="step", distribution=dist.Normal(0.0, 1.0) ), **kwargs, ) From d12ba6cd6ac8146a8eed2e6846eb950e5e21231b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 18:23:06 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyrenew/process/iidrandomsequence.py | 4 ++-- pyrenew/process/randomwalk.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index ab156513..2f280a5f 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -37,7 +37,7 @@ def __init__( Returns ------- None - """ + """ self.element_rv = element_rv super().__init__(name=name, **kwargs) @@ -85,7 +85,7 @@ def transition(_carry: None, _x: None) -> tuple[None, ArrayLike]: # numpydoc ignore=GL08 el = self.element_rv.sample(*args, **kwargs) return None, el - + _, result = scan( transition, xs=None, diff --git a/pyrenew/process/randomwalk.py b/pyrenew/process/randomwalk.py index cff83aca..76b16ec7 100644 --- a/pyrenew/process/randomwalk.py +++ b/pyrenew/process/randomwalk.py @@ -44,9 +44,7 @@ class constructor. """ super().__init__( name=name, - fundamental_process=IIDRandomSequence( - name="iid_seq", element_rv=step_rv - ), + fundamental_process=IIDRandomSequence(name="iid_seq", element_rv=step_rv), differencing_order=1, **kwargs, ) From 5e9717886db77e25fa68b47ed3429804ff595840 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Mon, 9 Mar 2026 18:23:53 -0400 Subject: [PATCH 4/4] Remove record --- pyrenew/deterministic/deterministic.py | 8 ++------ pyrenew/randomvariable/transformedvariable.py | 17 ++++++----------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/pyrenew/deterministic/deterministic.py b/pyrenew/deterministic/deterministic.py index 9e9b9c71..70141f23 100644 --- a/pyrenew/deterministic/deterministic.py +++ b/pyrenew/deterministic/deterministic.py @@ -72,13 +72,10 @@ def sample( **kwargs: object, ) -> ArrayLike: """ - Retrieve the value of the deterministic Rv + Retrieve the value of the deterministic RV Parameters ---------- - record - Whether to record the value of the deterministic - RandomVariable. Defaults to False. **kwargs Additional keyword arguments passed through to internal sample calls, should there be any. @@ -87,6 +84,5 @@ def sample( ------- ArrayLike """ - if record: - numpyro.deterministic(self.name, self) + numpyro.deterministic(self.name, self.value) return self.value diff --git a/pyrenew/randomvariable/transformedvariable.py b/pyrenew/randomvariable/transformedvariable.py index 22822d6e..1542ceb3 100644 --- a/pyrenew/randomvariable/transformedvariable.py +++ b/pyrenew/randomvariable/transformedvariable.py @@ -50,7 +50,7 @@ def __init__( self.transforms = transforms self.validate() - def sample(self, record: bool = False, **kwargs: object) -> tuple: + def sample(self, **kwargs: object) -> tuple: """ Sample method. Call self.base_rv.sample() and then apply the transforms specified @@ -58,9 +58,6 @@ def sample(self, record: bool = False, **kwargs: object) -> tuple: Parameters ---------- - record - Whether to record the value of the deterministic - RandomVariable. Defaults to False. **kwargs Keyword arguments passed to self.base_rv.sample() @@ -79,10 +76,11 @@ def sample(self, record: bool = False, **kwargs: object) -> tuple: t(uv) for t, uv in zip(self.transforms, untransformed_values) ) - if record: - if len(untransformed_values) == 1: - numpyro.deterministic(self.name, transformed_values) - else: + if len(transformed_values) == 1: + transformed_values = transformed_values[0] + numpyro.deterministic(self.name, transformed_values) + else: + with self.scope(): suffixes = ( untransformed_values._fields if hasattr(untransformed_values, "_fields") @@ -91,9 +89,6 @@ def sample(self, record: bool = False, **kwargs: object) -> tuple: for suffix, tv in zip(suffixes, transformed_values): numpyro.deterministic(f"{self.name}_{suffix}", tv) - if len(transformed_values) == 1: - transformed_values = transformed_values[0] - return transformed_values def sample_length(self) -> int: