diff --git a/docs/source/api/aesaraf.rst b/docs/source/api/aesaraf.rst index f610fbba84..3469cec8d8 100644 --- a/docs/source/api/aesaraf.rst +++ b/docs/source/api/aesaraf.rst @@ -10,17 +10,17 @@ Aesara utils gradient hessian hessian_diag + jacobian inputvars cont_inputs floatX intX smartfloatX - jacobian + constant_fold CallableTensor join_nonshared_inputs make_shared_replacements generator set_at_rng at_rng - take_along_axis pandas_to_array diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 62217bb5f2..42a4d025c9 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -34,7 +34,7 @@ from aesara import scalar from aesara.compile.mode import Mode, get_mode from aesara.gradient import grad -from aesara.graph import node_rewriter +from aesara.graph import node_rewriter, rewrite_graph from aesara.graph.basic import ( Apply, Constant, @@ -55,10 +55,13 @@ RandomGeneratorSharedVariable, RandomStateSharedVariable, ) +from aesara.tensor.rewriting.basic import topo_constant_folding +from aesara.tensor.rewriting.shape import ShapeFeature from aesara.tensor.sharedvar import SharedVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 from aesara.tensor.var import TensorConstant, TensorVariable +from pymc.exceptions import NotConstantValueError from pymc.vartypes import continuous_types, isgenerator, typefilter PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable] @@ -81,6 +84,8 @@ "set_at_rng", "at_rng", "convert_observed_data", + "compile_pymc", + "constant_fold", ] @@ -823,7 +828,7 @@ def find_rng_nodes( def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]: - """Replace any RNG nodes upsteram of outputs by new RNGs of the same type + """Replace any RNG nodes upstream of outputs by new RNGs of the same type This can be used when combining a pre-existing graph with a cloned one, to ensure RNGs are unique across the two graphs. @@ -970,3 +975,26 @@ def compile_pymc( **kwargs, ) return aesara_function + + +def constant_fold(xs: Sequence[TensorVariable]) -> Tuple[np.ndarray, ...]: + """Use constant folding to get constant values of a graph. + + Parameters + ---------- + xs: Sequence of TensorVariable + The variables that are to be constant folded + + Raises + ------ + NotConstantValueError: + If any of the variables cannot be successfully constant folded + """ + fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True) + + folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs + + if not all(isinstance(folded_x, Constant) for folded_x in folded_xs): + raise NotConstantValueError + + return tuple(folded_x.data for folded_x in folded_xs) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 7e0114a8cc..0996dec64a 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -27,10 +27,8 @@ from aeppl.tensor import MeasurableJoin from aeppl.transforms import TransformValuesRewrite from aesara import tensor as at -from aesara.graph import FunctionGraph, rewrite_graph from aesara.graph.basic import graph_inputs, io_toposort from aesara.tensor.random.op import RandomVariable -from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding from aesara.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -41,7 +39,8 @@ ) from aesara.tensor.var import TensorVariable -from pymc.aesaraf import floatX +from pymc.aesaraf import constant_fold, floatX +from pymc.exceptions import NotConstantValueError def _get_scaling( @@ -338,12 +337,10 @@ def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs): base_var_shapes = [base_var.shape[axis] for base_var in base_vars] - shape_fg = FunctionGraph( - outputs=base_var_shapes, - features=[ShapeFeature()], - clone=True, - ) - base_var_shapes = rewrite_graph(shape_fg, custom_opt=topo_constant_folding).outputs + try: + base_var_shapes = constant_fold(base_var_shapes) + except NotConstantValueError: + pass split_values = at.split( value, diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 8d8bc97c37..bb9f6a9da3 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -21,14 +21,12 @@ from aeppl.abstract import _get_measurable_outputs from aeppl.logprob import _logprob -from aesara.graph import FunctionGraph, rewrite_graph from aesara.graph.basic import Node, clone_replace from aesara.raise_op import Assert from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable -from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding -from pymc.aesaraf import convert_observed_data, floatX, intX +from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX from pymc.distributions import distribution, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.distribution import ( @@ -46,6 +44,7 @@ convert_dims, to_tuple, ) +from pymc.exceptions import NotConstantValueError from pymc.model import modelcontext from pymc.util import check_dist_not_registered @@ -472,14 +471,9 @@ def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: If inferred ar_order cannot be inferred from rhos or if it is less than 1 """ if ar_order is None: - shape_fg = FunctionGraph( - outputs=[rhos.shape[-1]], - features=[ShapeFeature()], - clone=True, - ) - (folded_shape,) = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs - folded_shape = getattr(folded_shape, "data", None) - if folded_shape is None: + try: + (folded_shape,) = constant_fold((rhos.shape[-1],)) + except NotConstantValueError: raise ValueError( "Could not infer ar_order from last dimension of rho. Pass it " "explictily or make sure rho have a static shape" diff --git a/pymc/exceptions.py b/pymc/exceptions.py index 0d7ba3eaaf..5b4141f303 100644 --- a/pymc/exceptions.py +++ b/pymc/exceptions.py @@ -76,5 +76,9 @@ def __init__(self, message, actual=None, expected=None): super().__init__(message) -class TruncationError(Exception): +class TruncationError(RuntimeError): """Exception for errors generated from truncated graphs""" + + +class NotConstantValueError(ValueError): + pass diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 8fc641d01c..71ddea89c8 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -35,6 +35,7 @@ from pymc.aesaraf import ( compile_pymc, + constant_fold, convert_observed_data, extract_obs_data, replace_rng_nodes, @@ -45,6 +46,7 @@ from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicRandomVariable from pymc.distributions.transforms import Interval +from pymc.exceptions import NotConstantValueError from pymc.vartypes import int_types @@ -610,3 +612,20 @@ def test_reseed_rngs(): assert rng.get_value()._bit_generator.state == bit_generator.state else: assert rng.get_value().bit_generator.state == bit_generator.state + + +def test_constant_fold(): + x = at.random.normal(size=(5,)) + y = at.arange(x.size) + + res = constant_fold((y, y.shape)) + assert np.array_equal(res[0], np.arange(5)) + assert tuple(res[1]) == (5,) + + +def test_constant_fold_error(): + x = at.vector("x") + y = at.arange(x.size) + + with pytest.raises(NotConstantValueError): + constant_fold((y, y.shape))