diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index fd6e414605..c0c957f20e 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1323,6 +1323,16 @@ def logcdf(value, a, b): msg="a > 0, b > 0", ) + def icdf(value, a, b): + res = pt.exp(pt.reciprocal(a) * pt.log1mexp(pt.reciprocal(b) * pt.log1p(-value))) + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + a > 0, + b > 0, + msg="a > 0, b > 0", + ) + class Exponential(PositiveContinuous): r""" diff --git a/pymc/testing.py b/pymc/testing.py index 0405cc2f6c..551356ee33 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -668,6 +668,52 @@ def check_selfconsistency_discrete_logcdf( ) +def check_selfconsistency_icdf( + distribution: Distribution, + paramdomains: dict[str, Domain], + *, + decimal: int | None = None, + n_samples: int = 100, +) -> None: + """Check that the icdf and logcdf functions of the distribution are consistent. + + Only works with continuous distributions. + """ + if decimal is None: + decimal = select_by_precision(float64=6, float32=3) + + dist = create_dist_from_paramdomains(distribution, paramdomains) + if dist.type.dtype.startswith("int"): + raise NotImplementedError( + "check_selfconsistency_icdf is not robust against discrete distributions." + ) + value = dist.astype("float64").type("value") + dist_icdf = icdf(dist, value) + dist_cdf = pt.exp(logcdf(dist, value)) + + py_mode = Mode("py") + dist_icdf_fn = pytensor.function(list(inputvars(dist_icdf)), dist_icdf, mode=py_mode) + dist_cdf_fn = compile(list(inputvars(dist_cdf)), dist_cdf, mode=py_mode) + + domains = paramdomains.copy() + domains["value"] = Domain(np.linspace(0, 1, 10)) + + for point in product(domains, n_samples=n_samples): + point = dict(point) + value = point.pop("value") + icdf_value = dist_icdf_fn(**point, value=value) + recovered_value = dist_cdf_fn( + **point, + value=icdf_value, + ) + np.testing.assert_almost_equal( + value, + recovered_value, + decimal=decimal, + err_msg=f"point: {point}", + ) + + def assert_support_point_is_expected(model, expected, check_finite_logp=True): fn = make_initial_point_fn( model=model, diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 7209382666..ef21aea1a3 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -45,6 +45,7 @@ check_icdf, check_logcdf, check_logp, + check_selfconsistency_icdf, continuous_random_tester, seeded_numpy_distribution_builder, seeded_scipy_distribution_builder, @@ -441,6 +442,10 @@ def scipy_log_cdf(value, a, b): {"a": Rplus, "b": Rplus}, scipy_log_cdf, ) + check_selfconsistency_icdf( + pm.Kumaraswamy, + {"a": Rplusbig, "b": Rplusbig}, + ) def test_exponential(self): check_logp(