Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,24 +977,28 @@ def compile_pymc(
return aesara_function


def constant_fold(xs: Sequence[TensorVariable]) -> Tuple[np.ndarray, ...]:
def constant_fold(
xs: Sequence[TensorVariable], raise_not_constant: bool = True
) -> Tuple[np.ndarray, ...]:
"""Use constant folding to get constant values of a graph.

Parameters
----------
xs: Sequence of TensorVariable
The variables that are to be constant folded

Raises
------
NotConstantValueError:
If any of the variables cannot be successfully constant folded
raise_not_constant: bool, default True
Raises NotConstantValueError if any of the variables cannot be constant folded.
This should only be disabled with care, as the graphs are cloned before
attempting constant folding, and any old non-shared inputs will not work with
the returned outputs
"""
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)

folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs

if not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
raise NotConstantValueError

return tuple(folded_x.data for folded_x in folded_xs)
return tuple(
folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
)
7 changes: 2 additions & 5 deletions pymc/distributions/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from aesara.tensor.var import TensorVariable

from pymc.aesaraf import constant_fold, floatX
from pymc.exceptions import NotConstantValueError


def _get_scaling(
Expand Down Expand Up @@ -337,10 +336,8 @@ def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):

base_var_shapes = [base_var.shape[axis] for base_var in base_vars]

try:
base_var_shapes = constant_fold(base_var_shapes)
except NotConstantValueError:
pass
# We don't need the graph to be constant, just to have RandomVariables removed
base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)

split_values = at.split(
value,
Expand Down
8 changes: 6 additions & 2 deletions pymc/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,13 @@ def test_constant_fold():
assert tuple(res[1]) == (5,)


def test_constant_fold_error():
x = at.vector("x")
def test_constant_fold_raises():
size = aesara.shared(5)
x = at.random.normal(size=(size,))
y = at.arange(x.size)

with pytest.raises(NotConstantValueError):
constant_fold((y, y.shape))

res = constant_fold((y, y.shape), raise_not_constant=False)
assert tuple(res[1].eval()) == (5,)