2 changes: 1 addition & 1 deletion pymc/model.py
@@ -1149,7 +1149,7 @@ def add_coord(
length = len(values)
if not isinstance(length, Variable):
if mutable:
length = aesara.shared(length)
length = aesara.shared(length, name=name)
else:
length = aesara.tensor.constant(length)
self._dim_lengths[name] = length
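The change above only gives the dim-length shared variable a name. A minimal sketch of the effect, not part of the diff, assuming a hypothetical dimension called "time":

import aesara

# Previously the shared length was anonymous; naming it after the coordinate
# makes it identifiable when the graph is inspected later on.
length = aesara.shared(3, name="time")  # "time" is a made-up dimension name
print(length.name)  # -> time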
68 changes: 62 additions & 6 deletions pymc/sampling.py
@@ -1621,6 +1621,8 @@ def compile_forward_sampling_function(
vars_in_trace: List[Variable],
basic_rvs: Optional[List[Variable]] = None,
givens_dict: Optional[Dict[Variable, Any]] = None,
constant_data: Optional[Dict[str, np.ndarray]] = None,
constant_coords: Optional[Set[str]] = None,
**kwargs,
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -1634,18 +1636,18 @@ def compile_forward_sampling_function(
compiled function or after inference has been run. These variables are:

- Variables in the outputs list
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
- Variables that are keys in the ``givens_dict``
- Variables that have volatile inputs

Concretely, this function can be used to compile a function to sample from the
posterior predictive distribution of a model that has variables that are conditioned
on ``MutableData`` instances. The variables that depend on the mutable data will be
considered volatile, and as such, they wont be included as inputs into the compiled function.
This means that if they have values stored in the posterior, these values will be ignored
and new values will be computed (in the case of deterministics and potentials) or sampled
(in the case of random variables).
on ``MutableData`` instances. Variables that depend on mutable data whose values have changed
since inference will be considered volatile, and as such, they won't be included as inputs
into the compiled function. This means that if they have values stored in the posterior,
these values will be ignored and new values will be computed (in the case of deterministics
and potentials) or sampled (in the case of random variables).

This function also enables a way to impute values for any variable in the computational
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
@@ -1672,6 +1674,25 @@ def compile_forward_sampling_function(
A dictionary that maps tensor variables to the values that should be used to replace them
in the compiled function. The types of the key and value should match or an error will be
raised during compilation.
constant_data : Optional[Dict[str, numpy.ndarray]]
A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their
corresponding values at inference time. If a model was created with ``MutableData``, these
are stored as ``SharedVariable`` instances with the name of the data variable and a value
equal to the initial data. At inference time, this information is stored in ``InferenceData``
objects under the ``constant_data`` group, which allows us to check whether a
``SharedVariable`` instance changed its value after inference. If the value has changed,
the ``SharedVariable`` is assumed to be volatile; if not, it is assumed to not be volatile.
If a ``SharedVariable`` is not found in either ``constant_data`` or ``constant_coords``,
it is assumed to be volatile.
Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary.
constant_coords : Optional[Set[str]]
A set with the names of the mutable coordinates that have not changed their shape after
inference. If a model was created with mutable coordinates, these are stored as
``SharedVariable`` instances with the name of the coordinate and a value equal to the
length of said coordinate. This set lets us check whether a ``SharedVariable`` corresponds
to an unchanged coordinate; any coordinate that was mutated is considered volatile.
If a ``SharedVariable`` is not found in either ``constant_data`` or ``constant_coords``,
it is assumed to be volatile.
Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.

Returns
-------
@@ -1687,6 +1708,20 @@ def compile_forward_sampling_function(
if basic_rvs is None:
basic_rvs = []

if constant_data is None:
constant_data = {}
if constant_coords is None:
constant_coords = set()

# We define a helper function to check whether a shared variable's current value matches its value at inference time
def shared_value_matches(var):
try:
old_array_value = constant_data[var.name]
except KeyError:
return var.name in constant_coords
current_shared_value = var.get_value(borrow=True)
return np.array_equal(old_array_value, current_shared_value)

# We need a function graph to walk the clients and propagate the volatile property
fg = FunctionGraph(outputs=outputs, clone=False)

@@ -1702,6 +1737,7 @@ def compile_forward_sampling_function(
or ( # SharedVariables, except RandomState/Generators
isinstance(node, SharedVariable)
and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
and not shared_value_matches(node)
)
or ( # Basic RVs that are not in the trace
node in basic_rvs and node not in vars_in_trace
@@ -1835,16 +1871,24 @@ def sample_posterior_predictive(
idata_kwargs = {}
else:
idata_kwargs = idata_kwargs.copy()
constant_data: Dict[str, np.ndarray] = {}
trace_coords: Dict[str, np.ndarray] = {}
if "coords" not in idata_kwargs:
idata_kwargs["coords"] = {}
if isinstance(trace, InferenceData):
idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
_constant_data = getattr(trace, "constant_data", None)
if _constant_data is not None:
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
_trace = dataset_to_point_list(trace["posterior"])
nchain, len_trace = chains_and_samples(trace)
elif isinstance(trace, xarray.Dataset):
idata_kwargs["coords"].setdefault("draw", trace["draw"])
idata_kwargs["coords"].setdefault("chain", trace["chain"])
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
_trace = dataset_to_point_list(trace)
nchain, len_trace = chains_and_samples(trace)
elif isinstance(trace, MultiTrace):
@@ -1901,6 +1945,16 @@ def sample_posterior_predictive(
stacklevel=2,
)

constant_coords = set()
for dim, coord in trace_coords.items():
current_coord = model.coords.get(dim, None)
if (
current_coord is not None
and len(coord) == len(current_coord)
and np.all(coord == current_coord)
):
constant_coords.add(dim)

if var_names is not None:
vars_ = [model[x] for x in var_names]
else:
@@ -1935,6 +1989,8 @@ def sample_posterior_predictive(
basic_rvs=model.basic_RVs,
givens_dict=None,
random_seed=random_seed,
constant_data=constant_data,
constant_coords=constant_coords,
**compile_kwargs,
)
sampler_fn = point_wrapper(_sampler_fn)
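Stepping back from the diff, a small self-contained sketch of the ``constant_coords`` computation added to ``sample_posterior_predictive`` above, with made-up coordinate values standing in for ``trace_coords`` and ``model.coords``:

import numpy as np

trace_coords = {"city": np.array(["A", "B"]), "time": np.array([0, 1, 2])}
model_coords = {"city": ["A", "B"], "time": [0, 1, 2, 3]}  # "time" grew after inference

constant_coords = set()
for dim, coord in trace_coords.items():
    current_coord = model_coords.get(dim, None)
    if (
        current_coord is not None
        and len(coord) == len(current_coord)
        and np.all(coord == current_coord)
    ):
        constant_coords.add(dim)

assert constant_coords == {"city"}  # only unchanged dims are treated as constant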
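Finally, end to end, the behaviour this PR targets looks roughly like the sketch below; the model, variable names and data are invented, and default ``pm.sample`` settings are assumed:

import numpy as np
import pymc as pm

x_obs = np.array([0.0, 1.0, 2.0])
y_obs = np.array([0.1, 0.9, 2.1])

with pm.Model() as model:
    x = pm.MutableData("x", x_obs)
    beta = pm.Normal("beta")
    mu = pm.Deterministic("mu", beta * x)
    pm.Normal("y", mu=mu, sigma=0.1, observed=y_obs)
    idata = pm.sample()

with model:
    # Data unchanged: "mu" is not volatile, so its posterior values are reused.
    pm.sample_posterior_predictive(idata, var_names=["mu", "y"])

    # After mutating "x", everything downstream of it becomes volatile and is
    # recomputed ("mu") or resampled ("y") instead of being read from the trace.
    pm.set_data({"x": np.array([3.0, 4.0, 5.0])})
    pm.sample_posterior_predictive(idata, var_names=["mu", "y"])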