Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax.random as jr
import numpy as np
import numpyro
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample

Expand Down Expand Up @@ -101,6 +102,23 @@ def validate(**kwargs: object) -> None:
"""
pass

def scope(self) -> numpyro.handlers.scope:
"""
Standardized [`numpyro.handlers.scope`][] context for
PyRenew [`RandomVariable`][]s. This can be used to
naming of any internal sampling sites within the
[`RandomVariable`][]'s [`self.sample()`][] method.

The scope prefix is always the [`name`][self.name] of the `RandomVariable`
and the divider is always `_`.

Returns
-------
numpyro.handlers.scope
A properly configured scope handler.
"""
return numpyro.handlers.scope(prefix=self.name, divider="_")

def __call__(self, **kwargs: object) -> tuple:
"""
Alias for `sample`.
Expand Down
Loading