diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index f45f0ccb00..6dbcf09971 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -66,6 +66,7 @@
 from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorVariable
 
+from pymc.exceptions import NotConstantValueError
 from pymc.logprob.abstract import (
     MeasurableElemwise,
     MeasurableOp,
@@ -82,6 +83,7 @@
 )
 from pymc.logprob.utils import (
     check_potential_measurability,
+    dirac_delta,
     filter_measurable_variables,
     get_related_valued_nodes,
 )
@@ -414,21 +416,58 @@ def find_measurable_switch_mixture(fgraph, node):
     switch_cond, *components = node.inputs
 
+    # Require at least one measurable component, otherwise there is no logprob to compute
+    measurable_components = filter_measurable_variables(components)
+    if not measurable_components:
+        return None
+
+    # Only allow non-measurable components if they are compile-time constants
+    measurable_ids = {id(c) for c in measurable_components}
+    non_measurable_components = [c for c in components if id(c) not in measurable_ids]
+    folded_constants: dict[int, TensorVariable] = {}
+    for comp in non_measurable_components:
+        if isinstance(comp, Constant):
+            folded_constants[id(comp)] = comp
+            continue
+        try:
+            (folded_comp,) = constant_fold([comp], raise_not_constant=True)
+        except NotConstantValueError:
+            return None
+        if not isinstance(folded_comp, TensorVariable):
+            folded_comp = pt.constant(folded_comp)
+        if not isinstance(folded_comp, Constant):
+            return None
+        folded_constants[id(comp)] = folded_comp
+
+    # Use a measurable component as the broadcasting reference for constant branches
+    bcast_ref = measurable_components[0]
+    new_components: list[TensorVariable] = []
+    for comp in components:
+        if id(comp) in measurable_ids:
+            new_components.append(comp)
+        else:
+            # Treat constant branches as point masses so we can compute a logp.
+            # Broadcasting is allowed for constants because it does not introduce dependence.
+            const_comp = folded_constants[id(comp)]
+            bcast_comp, _ = pt.broadcast_arrays(const_comp, bcast_ref)
+            new_components.append(dirac_delta(bcast_comp))
+
     # We don't support broadcasting of components, as that yields dependent (identical) values.
     # The current logp implementation assumes all component values are independent.
     # Broadcasting of the switch condition is fine
     out_bcast = node.outputs[0].type.broadcastable
-    if any(comp.type.broadcastable != out_bcast for comp in components):
-        return None
-
-    if set(filter_measurable_variables(components)) != set(components):
+    if any(
+        comp.type.broadcastable != out_bcast
+        for comp in new_components
+        if id(comp) in measurable_ids
+    ):
         return None
 
     # Check that `switch_cond` is not potentially measurable
     if check_potential_measurability([switch_cond]):
         return None
 
-    return [measurable_switch_mixture(switch_cond, *components)]
+    return [measurable_switch_mixture(switch_cond, *new_components)]
 
 
 @_logprob.register(MeasurableSwitchMixture)
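What the rewrite above enables, as a minimal sketch (assuming the PyMC v5 logprob API; `logp` and `warn_rvs` are used exactly as in the new tests further down): a `switch` whose branches mix a measurable variable with a compile-time constant now gets a logprob, with the constant branch treated as a point mass.

```python
import numpy as np
import pytensor.tensor as pt

from pymc.logprob.basic import logp

t = pt.arange(4)
x_rv = pt.random.normal(0.0, 1.0, size=(4,), name="x")
# The constant branch was rejected by the old rewrite; it is now folded into a DiracDelta
y = pt.where(t > 1, x_rv, -1.0)

value = y.clone()
# warn_rvs=False as in the new test below: the broadcast constant branch may still
# reference x_rv through its shape graph
y_logp = logp(y, value, warn_rvs=False)

# The first two entries are point masses at -1.0 (logp 0.0); the rest get the Normal logp
print(y_logp.eval({value: np.array([-1.0, -1.0, 0.5, -0.3])}))
```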
diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index ee34c5acf4..6e1d414c7b 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -39,6 +39,7 @@
 from numpy.lib.array_utils import normalize_axis_index
 
 from pytensor import tensor as pt
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor import TensorVariable
@@ -65,6 +66,7 @@
 )
 from pymc.logprob.utils import (
     check_potential_measurability,
+    dirac_delta,
     filter_measurable_variables,
     get_related_valued_nodes,
     replace_rvs_by_values,
@@ -162,9 +164,29 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
     else:
         base_vars = node.inputs
 
-    if not all(check_potential_measurability([base_var]) for base_var in base_vars):
+    # Allow mixing potentially measurable inputs with compile-time constants
+    new_base_vars: list[TensorVariable] = []
+    has_measurable = False
+    for base_var in base_vars:
+        if check_potential_measurability([base_var]):
+            has_measurable = True
+            new_base_vars.append(base_var)
+        else:
+            if isinstance(base_var, Constant):
+                folded_var = base_var
+            else:
+                try:
+                    (folded_var,) = constant_fold([base_var], raise_not_constant=True)
+                except NotConstantValueError:
+                    return None
+                if not isinstance(folded_var, TensorVariable):
+                    folded_var = pt.constant(folded_var)
+                if not isinstance(folded_var, Constant):
+                    return None
+            new_base_vars.append(dirac_delta(folded_var))
+    if not has_measurable:
         return None
-
+    base_vars = new_base_vars
     base_vars = assume_valued_outputs(base_vars)
     if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_vars):
         return None
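The `find_measurable_stacks` change admits the analogous pattern for `stack`/`join` (same assumptions as the sketch above): constant entries become `DiracDelta` point masses inside the joint logprob, so mismatching values score `-inf`.

```python
import numpy as np
import pytensor.tensor as pt

from pymc.logprob.basic import logp

x_rv = pt.random.normal(name="x")
# Mixing a measurable input with a constant previously made the whole stack non-measurable
y_rv = pt.stack([x_rv, pt.constant(1.0)])

value = y_rv.clone()
y_logp = logp(y_rv, value)

# Entry 0 gets the Normal logp; entry 1 is 0.0 at the point mass and -inf elsewhere
print(y_logp.eval({value: np.array([0.0, 1.0])}))  # approx [-0.9189, 0.0]
```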
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py
index eb9fc81488..23ac594801 100644
--- a/tests/logprob/test_mixture.py
+++ b/tests/logprob/test_mixture.py
@@ -52,6 +52,8 @@
     as_index_constant,
 )
 
+import pymc as pm
+
 from pymc.logprob.abstract import MeasurableOp
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices
@@ -971,6 +973,22 @@ def test_switch_mixture_invalid_bcast():
     assert not isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableOp)
 
+
+def test_switch_mixture_constant_branch_broadcast_ok():
+    t = pt.arange(10)
+    cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
+    cat_fixed_const = pt.where(t > 5, cat, -1)
+    cat_fixed_dirac = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))
+    vv_const = cat_fixed_const.clone()
+    vv_dirac = cat_fixed_dirac.clone()
+    logp_const = logp(cat_fixed_const, vv_const, warn_rvs=False)
+    logp_dirac = logp(cat_fixed_dirac, vv_dirac, warn_rvs=False)
+    test_value = np.where(np.arange(10) > 5, 0, -1).astype(vv_const.dtype)
+    np.testing.assert_allclose(
+        logp_const.eval({vv_const: test_value}),
+        logp_dirac.eval({vv_dirac: test_value.astype(vv_dirac.dtype)}),
+    )
+
 
 def test_ifelse_mixture_one_component():
     if_rv = pt.random.bernoulli(0.5, name="if")
     scale_rv = pt.random.halfnormal(name="scale")
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py
index fe9a875977..ec64611f61 100644
--- a/tests/logprob/test_tensor.py
+++ b/tests/logprob/test_tensor.py
@@ -101,6 +101,30 @@ def test_measurable_make_vector():
     assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)
 
+
+def test_measurable_make_vector_with_constant_input():
+    base1_rv = pt.random.normal(name="base1")
+    base2_rv = pt.random.halfnormal(name="base2")
+    y_rv = pt.stack((base1_rv, pt.constant(0.0), base2_rv))
+    y_rv.name = "y"
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    y_vv = y_rv.clone()
+    ref_logp = conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv})
+    ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()])
+    y_logp = logp(y_rv, y_vv)
+    base1_testval = base1_rv.eval()
+    base2_testval = base2_rv.eval()
+    y_testval = np.stack((base1_testval, 0.0, base2_testval)).astype(y_vv.dtype)
+    ref_logp_eval = ref_logp_combined.eval({base1_vv: base1_testval, base2_vv: base2_testval})
+    y_logp_eval = y_logp.eval({y_vv: y_testval})
+    assert y_logp_eval.shape == y_testval.shape
+    assert np.isclose(y_logp_eval.sum(), ref_logp_eval)
+    y_testval_bad = y_testval.copy()
+    y_testval_bad[1] = 1.0
+    y_logp_eval_bad = y_logp.eval({y_vv: y_testval_bad})
+    assert y_logp_eval_bad[1] == -np.inf
+
 
 @pytest.mark.parametrize("reverse", (False, True))
 def test_measurable_make_vector_interdependent(reverse):
     """Test that we can obtain a proper graph when stacked RVs depend on each other"""
@@ -190,6 +214,31 @@ def test_measurable_join_interdependent(reverse):
     )
 
+
+def test_measurable_join_with_constant_input():
+    base1_rv = pt.random.normal(size=(2,), name="base1")
+    base2_rv = pt.random.exponential(size=(3,), name="base2")
+    const = pt.constant(np.array([0.0, 0.0, 0.0]))
+    y_rv = pt.join(0, base1_rv, const, base2_rv)
+    y_rv.name = "y"
+    base1_vv = base1_rv.clone()
+    base2_vv = base2_rv.clone()
+    y_vv = y_rv.clone()
+    ref_logp = conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv})
+    ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()])
+    y_logp = logp(y_rv, y_vv)
+    base1_testval = base1_rv.eval()
+    base2_testval = base2_rv.eval()
+    y_testval = np.concatenate([base1_testval, np.zeros(3), base2_testval]).astype(y_vv.dtype)
+    ref_logp_eval = ref_logp_combined.eval({base1_vv: base1_testval, base2_vv: base2_testval})
+    y_logp_eval = y_logp.eval({y_vv: y_testval})
+    assert y_logp_eval.shape == y_testval.shape
+    assert np.isclose(y_logp_eval.sum(), ref_logp_eval)
+    y_testval_bad = y_testval.copy()
+    y_testval_bad[2] = 1.0
+    y_logp_eval_bad = y_logp.eval({y_vv: y_testval_bad})
+    assert y_logp_eval_bad[2] == -np.inf
+
 
 @pytest.mark.parametrize(
     "size1, size2, axis, concatenate",
     [