From a565e36ef1d3e64a6ebb059d4846ed53226f7be5 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Tue, 23 Dec 2025 10:23:42 +0100 Subject: [PATCH 1/3] Make HalfCauchy a symbolic RV --- pymc/dims/distributions/scalar.py | 41 ++++++++++++----------- pymc/distributions/continuous.py | 45 ++++++++++++++++++-------- tests/distributions/test_continuous.py | 9 ++++-- tests/logprob/test_transform_value.py | 6 +++- tests/test_printing.py | 2 +- 5 files changed, 66 insertions(+), 37 deletions(-) diff --git a/pymc/dims/distributions/scalar.py b/pymc/dims/distributions/scalar.py index 6cdb3ccc73..87f6e36649 100644 --- a/pymc/dims/distributions/scalar.py +++ b/pymc/dims/distributions/scalar.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytensor.xtensor as ptx -import pytensor.xtensor.random as pxr +import pytensor.xtensor.random as ptxr from pytensor.xtensor import as_xtensor @@ -23,7 +23,7 @@ ) from pymc.distributions.continuous import Beta as RegularBeta from pymc.distributions.continuous import Gamma as RegularGamma -from pymc.distributions.continuous import HalfStudentTRV, flat, halfflat +from pymc.distributions.continuous import HalfCauchyRV, HalfStudentTRV, flat, halfflat def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): @@ -40,7 +40,7 @@ def _get_sigma_from_either_sigma_or_tau(*, sigma, tau): class Flat(DimDistribution): - xrv_op = pxr.as_xrv(flat) + xrv_op = ptxr.as_xrv(flat) @classmethod def dist(cls, **kwargs): @@ -48,7 +48,7 @@ def dist(cls, **kwargs): class HalfFlat(PositiveDimDistribution): - xrv_op = pxr.as_xrv(halfflat, [], ()) + xrv_op = ptxr.as_xrv(halfflat, [], ()) @classmethod def dist(cls, **kwargs): @@ -56,7 +56,7 @@ def dist(cls, **kwargs): class Normal(DimDistribution): - xrv_op = pxr.normal + xrv_op = ptxr.normal @classmethod def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): @@ -65,7 +65,7 @@ def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): class HalfNormal(PositiveDimDistribution): - xrv_op = pxr.halfnormal + xrv_op = ptxr.halfnormal @classmethod def dist(cls, sigma=None, *, tau=None, **kwargs): @@ -74,7 +74,7 @@ def dist(cls, sigma=None, *, tau=None, **kwargs): class LogNormal(PositiveDimDistribution): - xrv_op = pxr.lognormal + xrv_op = ptxr.lognormal @classmethod def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): @@ -83,7 +83,7 @@ def dist(cls, mu=0, sigma=None, *, tau=None, **kwargs): class StudentT(DimDistribution): - xrv_op = pxr.t + xrv_op = ptxr.t @classmethod def dist(cls, nu, mu=0, sigma=None, *, lam=None, **kwargs): @@ -102,12 +102,12 @@ def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None): nu = as_xtensor(nu) sigma = as_xtensor(sigma) core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op - xop = pxr.as_xrv(core_rv) + xop = ptxr.as_xrv(core_rv) return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng) class Cauchy(DimDistribution): - xrv_op = pxr.cauchy + xrv_op = ptxr.cauchy @classmethod def dist(cls, alpha, beta, **kwargs): @@ -115,15 +115,20 @@ def dist(cls, alpha, beta, **kwargs): class HalfCauchy(PositiveDimDistribution): - xrv_op = pxr.halfcauchy - @classmethod def dist(cls, beta, **kwargs): - return super().dist([0.0, beta], **kwargs) + return super().dist([beta], **kwargs) + + @classmethod + def xrv_op(self, beta, core_dims, extra_dims=None, rng=None): + beta = as_xtensor(beta) + core_rv = HalfCauchyRV.rv_op(beta=beta.values).owner.op + xop = ptxr.as_xrv(core_rv) + return xop(beta, core_dims=core_dims, 
extra_dims=extra_dims, rng=rng)
 
 
 class Beta(UnitDimDistribution):
-    xrv_op = pxr.beta
+    xrv_op = ptxr.beta
 
     @classmethod
     def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs):
@@ -132,7 +137,7 @@ def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, nu=None, **kwargs):
 
 
 class Laplace(DimDistribution):
-    xrv_op = pxr.laplace
+    xrv_op = ptxr.laplace
 
     @classmethod
     def dist(cls, mu=0, b=1, **kwargs):
@@ -140,7 +145,7 @@ def dist(cls, mu=0, b=1, **kwargs):
 
 
 class Exponential(PositiveDimDistribution):
-    xrv_op = pxr.exponential
+    xrv_op = ptxr.exponential
 
     @classmethod
     def dist(cls, lam=None, *, scale=None, **kwargs):
@@ -154,7 +159,7 @@ def dist(cls, lam=None, *, scale=None, **kwargs):
 
 
 class Gamma(PositiveDimDistribution):
-    xrv_op = pxr.gamma
+    xrv_op = ptxr.gamma
 
     @classmethod
     def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
@@ -173,7 +178,7 @@ def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
 
 
 class InverseGamma(PositiveDimDistribution):
-    xrv_op = pxr.invgamma
+    xrv_op = ptxr.invgamma
 
     @classmethod
     def dist(cls, alpha=None, beta=None, *, mu=None, sigma=None, **kwargs):
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index 1d30227fed..8764012109 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -36,7 +36,6 @@
     cauchy,
     exponential,
     gumbel,
-    halfcauchy,
     halfnormal,
     invgamma,
     laplace,
@@ -2218,6 +2217,23 @@ def icdf(value, alpha, beta):
         )
 
 
+class HalfCauchyRV(SymbolicRandomVariable):
+    name = "halfcauchy"
+    extended_signature = "[rng],[size],()->[rng],()"
+    _print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}")
+
+    @classmethod
+    def rv_op(cls, beta, *, size=None, rng=None):
+        beta = pt.as_tensor(beta)
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+
+        next_rng, cauchy_draws = cauchy(loc=0, scale=beta, size=size, rng=rng).owner.outputs
+        draws = pt.abs(cauchy_draws)
+
+        return cls(inputs=[rng, size, beta], outputs=[next_rng, draws])(rng, size, beta)
+
+
 class HalfCauchy(PositiveContinuous):
     r"""
     Half-Cauchy distribution.
@@ -2258,32 +2274,33 @@ class HalfCauchy(PositiveContinuous):
         Scale parameter (beta > 0).
""" - rv_op = halfcauchy + rv_op = HalfCauchyRV.rv_op + rv_type = HalfCauchyRV @classmethod def dist(cls, beta, *args, **kwargs): beta = pt.as_tensor_variable(beta) - return super().dist([0.0, beta], **kwargs) + return super().dist([beta], **kwargs) - def support_point(rv, size, loc, beta): + def support_point(rv, size, beta): if not rv_size_is_none(size): beta = pt.full(size, beta) return beta - def logp(value, loc, beta): - res = pt.log(2) + _logprob_helper(Cauchy.dist(loc, beta), value) - res = pt.switch(pt.ge(value, loc), res, -np.inf) + def logp(value, beta): + res = pt.log(2) + _logprob_helper(Cauchy.dist(alpha=0, beta=beta), value) + res = pt.switch(value >= 0, res, -np.inf) return check_parameters( res, beta > 0, msg="beta > 0", ) - def logcdf(value, loc, beta): + def logcdf(value, beta): res = pt.switch( - pt.lt(value, loc), + value < 0, -np.inf, - pt.log(2 * pt.arctan((value - loc) / beta) / np.pi), + pt.log(2 * pt.arctan(value / beta) / np.pi), ) return check_parameters( @@ -2292,8 +2309,8 @@ def logcdf(value, loc, beta): msg="beta > 0", ) - def icdf(value, loc, beta): - res = loc + beta * pt.tan(np.pi * (value) / 2.0) + def icdf(value, beta): + res = beta * pt.tan(np.pi * (value) / 2.0) res = check_icdf_value(res, value) return check_icdf_parameters( res, @@ -2610,7 +2627,7 @@ class WeibullBetaRV(SymbolicRandomVariable): _print_name = ("Weibull", "\\operatorname{Weibull}") @classmethod - def rv_op(cls, alpha, beta, *, rng=None, size=None) -> np.ndarray: + def rv_op(cls, alpha, beta, *, rng=None, size=None): alpha = pt.as_tensor(alpha) beta = pt.as_tensor(beta) rng = normalize_rng_param(rng) @@ -2737,7 +2754,7 @@ class HalfStudentTRV(SymbolicRandomVariable): _print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}") @classmethod - def rv_op(cls, nu, sigma, *, size=None, rng=None) -> np.ndarray: + def rv_op(cls, nu, sigma, *, size=None, rng=None): nu = pt.as_tensor(nu) sigma = pt.as_tensor(sigma) rng = normalize_rng_param(rng) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index eb0c4cabe7..2ca5335d67 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -2326,11 +2326,14 @@ class TestCauchy(BaseTestDistributionRandom): class TestHalfCauchy(BaseTestDistributionRandom): + def halfcauchy_rng_fn(self, scale, size, rng): + return np.abs(st.cauchy.rvs(scale=scale, size=size, random_state=rng)) + pymc_dist = pm.HalfCauchy pymc_dist_params = {"beta": 5.0} - expected_rv_op_params = {"alpha": 0.0, "beta": 5.0} - reference_dist_params = {"loc": 0.0, "scale": 5.0} - reference_dist = seeded_scipy_distribution_builder("halfcauchy") + expected_rv_op_params = {"beta": 5.0} + reference_dist_params = {"scale": 5.0} + reference_dist = lambda self: ft.partial(self.halfcauchy_rng_fn, rng=self.get_random_state()) # noqa: E731 checks_to_run = [ "check_pymc_params_match_rv_op", "check_pymc_draws_match_reference", diff --git a/tests/logprob/test_transform_value.py b/tests/logprob/test_transform_value.py index f7df2bff79..61e2a1356e 100644 --- a/tests/logprob/test_transform_value.py +++ b/tests/logprob/test_transform_value.py @@ -135,11 +135,15 @@ def test_original_values_output_dict(): lambda mu, sigma: sp.stats.lognorm(s=sigma, scale=np.exp(mu)), (), ), - ( + pytest.param( pt.random.halfcauchy, (1.5, 10.5), lambda alpha, beta: sp.stats.halfcauchy(loc=alpha, scale=beta), (), + marks=pytest.mark.xfail( + reason="We don't use PyTensor's HalfCauchy operator", + raises=NotImplementedError, + ), ), ( 
pt.random.gamma,
diff --git a/tests/test_printing.py b/tests/test_printing.py
index 28f3df68a9..7efba2a81e 100644
--- a/tests/test_printing.py
+++ b/tests/test_printing.py
@@ -257,7 +257,7 @@ def test_model_latex_repr_three_levels_model():
         "$$",
         "\\begin{array}{rcl}",
         "\\text{mu} &\\sim & \\operatorname{Normal}(0,~5)\\\\\\text{sigma} &\\sim & "
-        "\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored\\_normal} &\\sim & "
+        "\\operatorname{HalfCauchy}(2.5)\\\\\\text{censored\\_normal} &\\sim & "
         "\\operatorname{Censored}(\\operatorname{Normal}(\\text{mu},~\\text{sigma}),~-2,~2)",
         "\\end{array}",
         "$$",
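[Editor's note — not part of the patches: PATCH 1/3 turns HalfCauchy into a SymbolicRandomVariable whose draws are |Cauchy(0, beta)| and drops the redundant loc parameter from dist, support_point, logp, logcdf, and icdf. Below is a minimal sketch of how one could sanity-check that behavior against scipy; it assumes only the public PyMC API (pm.HalfCauchy.dist, pm.draw, pm.logp) and is not taken from the patch itself.

import numpy as np
import scipy.stats as st
import pymc as pm

beta = 2.5
rv = pm.HalfCauchy.dist(beta=beta)

# Draws are generated as abs(Cauchy(0, beta)) draws, so all must be non-negative.
draws = pm.draw(rv, draws=10_000, random_seed=1)
assert (draws >= 0).all()

# The loc-free logp should agree with scipy's halfcauchy density
# (log 2 plus the Cauchy logpdf for x >= 0).
x = np.array([0.5, 1.0, 3.0])
np.testing.assert_allclose(
    pm.logp(rv, x).eval(),
    st.halfcauchy.logpdf(x, scale=beta),
    rtol=1e-6,
)
]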
From 4ff386c9107bf99c178b9f66bb0739a570f31f7f Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Sun, 14 Dec 2025 12:31:11 +0100
Subject: [PATCH 2/3] Add pytensor linker matrix ("cvm", "numba") to tests

Remove useless os parametrizations and update codecov name
---
 .github/workflows/tests.yml            | 63 ++++++++++++--------------
 scripts/check_all_tests_are_covered.py | 39 +++++++---------
 2 files changed, 45 insertions(+), 57 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3217f269f6..64e4511aa4 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -8,8 +8,6 @@ on:
     branches:
       - main
-
-
 # Cancel running workflows for updated PRs
 # https://turso.tech/blog/simple-trick-to-save-environment-and-money-when-using-github-actions
 concurrency:
@@ -26,7 +24,6 @@ concurrency:
 # enforces that test run just once per OS / floatX setting.
 
 jobs:
-
   changes:
     name: "Check for changes"
     runs-on: ubuntu-latest
@@ -56,8 +53,7 @@ jobs:
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
-        os: [ubuntu-latest]
-        floatx: [float64]
+        linker: [cvm, numba]
         python-version: ["3.14"]
         test-subset:
           - |
@@ -145,10 +141,10 @@ jobs:
             tests/dims/test_model.py
       fail-fast: false
-    runs-on: ${{ matrix.os }}
+    runs-on: ubuntu-latest
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }}
+      PYTENSOR_FLAGS: linker=${{ matrix.linker }}
     defaults:
       run:
         shell: bash -leo pipefail {0}
@@ -175,9 +171,9 @@ jobs:
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1
         with:
-          token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
+          token: ${{ secrets.CODECOV_TOKEN }}
           env_vars: TEST_SUBSET
-          name: ${{ matrix.os }} ${{ matrix.floatx }}
+          name: Ubuntu py${{ matrix.python-version }} linker=${{ matrix.linker }}
           fail_ci_if_error: false
@@ -185,8 +181,7 @@ jobs:
   windows:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
-        os: [windows-latest]
-        floatx: [float64]
+        linker: [cvm, numba]
         python-version: ["3.11"]
         test-subset:
           - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
@@ -195,10 +190,10 @@ jobs:
           - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py
       fail-fast: false
-    runs-on: ${{ matrix.os }}
+    runs-on: windows-latest
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }}
+      PYTENSOR_FLAGS: linker=${{ matrix.linker }}
     defaults:
       run:
         shell: bash -leo pipefail {0}
@@ -225,9 +220,9 @@ jobs:
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1
         with:
-          token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
+          token: ${{ secrets.CODECOV_TOKEN }}
           env_vars: TEST_SUBSET
-          name: ${{ matrix.os }} ${{ matrix.floatx }}
+          name: Windows py${{ matrix.python-version }} linker=${{ matrix.linker }}
           fail_ci_if_error: false
@@ -235,8 +230,7 @@ jobs:
   macos:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
-        os: [macos-latest]
-        floatx: [float64]
+        linker: [cvm, numba]
         python-version: ["3.14"]
         test-subset:
           - |
@@ -253,10 +247,10 @@ jobs:
             tests/backends/test_zarr.py
             tests/variational/test_updates.py
       fail-fast: false
-    runs-on: ${{ matrix.os }}
+    runs-on: macos-latest
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }}
+      PYTENSOR_FLAGS: linker=${{ matrix.linker }}
     defaults:
       run:
         shell: bash -leo pipefail {0}
@@ -285,7 +279,7 @@ jobs:
         with:
           token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
           env_vars: TEST_SUBSET
-          name: ${{ matrix.os }} ${{ matrix.floatx }}
+          name: MacOS py${{ matrix.python-version }} linker=${{ matrix.linker }}
           fail_ci_if_error: false
@@ -293,8 +287,7 @@ jobs:
   alternative_backends:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
-        os: [ubuntu-latest]
-        floatx: [float64]
+        linker: [cvm, numba]
         # nutpie depends on PyMC, and it will require an extra release cycle to support
         # the next PyMC release and therefore Python 3.14.
         python-version: ["3.13"]
         test-subset:
           - |
             tests/sampling/test_jax.py
             tests/sampling/test_mcmc_external.py
@@ -305,10 +298,10 @@ jobs:
       fail-fast: false
-    runs-on: ${{ matrix.os }}
+    runs-on: ubuntu-latest
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }}
+      PYTENSOR_FLAGS: linker=${{ matrix.linker }}
     defaults:
       run:
         shell: bash -leo pipefail {0}
@@ -337,7 +330,7 @@ jobs:
         with:
           token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
           env_vars: TEST_SUBSET
-          name: Alternative backend tests - ${{ matrix.os }} ${{ matrix.floatx }}
+          name: Alternative backends py${{ matrix.python-version }} linker=${{ matrix.linker }}
           fail_ci_if_error: false
@@ -345,16 +338,16 @@ jobs:
   float32:
     needs: changes
     if: ${{ needs.changes.outputs.changes == 'true' }}
     strategy:
       matrix:
+        linker: [cvm, numba]
         os: [windows-latest]
-        floatx: [float32]
         python-version: ["3.14"]
         test-subset:
-          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
+          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
       fail-fast: false
     runs-on: ${{ matrix.os }}
     env:
       TEST_SUBSET: ${{ matrix.test-subset }}
-      PYTENSOR_FLAGS: floatX=${{ matrix.floatx }}
+      PYTENSOR_FLAGS: floatX=float32,linker=${{ matrix.linker }}
     defaults:
       run:
         shell: bash -leo pipefail {0}
@@ -383,19 +376,19 @@ jobs:
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@5a1091511ad55cbe89839c7260b706298ca349f7 # v5.5.1
         with:
           token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads
           env_vars: TEST_SUBSET
-          name: ${{ matrix.os }} ${{ matrix.floatx }}
+          name: float32 ${{ matrix.os }} py${{ matrix.python-version }} linker=${{ matrix.linker }}
           fail_ci_if_error: false
 
   all_tests:
     if: ${{ always() }}
     runs-on: ubuntu-latest
-    needs: [ changes, ubuntu, windows, macos, alternative_backends, float32 ]
+    needs: [changes, ubuntu, windows, macos, alternative_backends, float32]
     steps:
       - name: Check build matrix status
         if: ${{ needs.changes.outputs.changes == 'true' &&
-          ( needs.ubuntu.result != 'success' ||
-            needs.windows.result != 'success' ||
-            needs.macos.result != 'success' ||
-            needs.alternative_backends.result != 'success' ||
-            needs.float32.result != 'success' ) }}
+          ( needs.ubuntu.result != 'success' ||
+            needs.windows.result != 'success' ||
+            
needs.macos.result != 'success' || + needs.alternative_backends.result != 'success' || + needs.float32.result != 'success' ) }} run: exit 1 diff --git a/scripts/check_all_tests_are_covered.py b/scripts/check_all_tests_are_covered.py index 23079338d6..59e5bec157 100644 --- a/scripts/check_all_tests_are_covered.py +++ b/scripts/check_all_tests_are_covered.py @@ -31,7 +31,7 @@ def find_testfiles(): def from_yaml(): - """Determine how often each test file is run per platform and floatX setting. + """Determine how often each test file is run per platform and linker setting. An exception is raised if tests run multiple times with the same configuration. """ @@ -41,25 +41,32 @@ def from_yaml(): wfname = wf.rstrip(".yml") wfdef = yaml.safe_load(open(Path(".github", "workflows", wf))) for jobname, jobdef in wfdef["jobs"].items(): + if jobname in ("float32", "all_tests"): + continue + runs_on = jobdef.get("runs-on", "unknown") + floatX = "float32" if jobname == "float32" else "float64" matrix = jobdef.get("strategy", {}).get("matrix", {}) if matrix: + # Some jobs are parametrized by os, for others it's fixed + matrix.setdefault("os", [runs_on]) + matrix.setdefault("floatX", [floatX]) matrices[(wfname, jobname)] = matrix else: _log.warning("No matrix in %s/%s", wf, jobname) - # Now create an empty DataFrame to count based on OS/floatX/testfile + # Now create an empty DataFrame to count based on OS/linker/testfile all_os = [] - all_floatX = [] + all_linker = [] for matrix in matrices.values(): all_os += matrix["os"] - all_floatX += matrix["floatx"] + all_linker += matrix["linker"] all_os = tuple(sorted(set(all_os))) - all_floatX = tuple(sorted(set(all_floatX))) + all_linker = tuple(sorted(set(all_linker))) all_tests = find_testfiles() df = pandas.DataFrame( columns=pandas.MultiIndex.from_product( - [sorted(all_floatX), sorted(all_os)], names=["floatX", "os"] + [sorted(all_linker), sorted(all_os)], names=["linker", "os"] ), index=pandas.Index(sorted(all_tests), name="testfile"), ) @@ -67,22 +74,10 @@ def from_yaml(): # Count how often the testfiles are included in job definitions for matrix in matrices.values(): - for os_, floatX, subset in itertools.product( - matrix["os"], matrix["floatx"], matrix["test-subset"] + for os_, linker, subset in itertools.product( + matrix["os"], matrix["linker"], matrix["test-subset"] ): lines = [k for k in subset.split("\n") if k] - if "windows" in os_: - # Windows jobs need \ in line breaks within the test-subset! - # The following checks that these trailing \ are present in - # all items except the last. - if lines and lines[-1].endswith(" \\"): - raise Exception( - f"Last entry '{lines}' in Windows test subset should end WITHOUT ' \\'." 
- ) - for line in lines[:-1]: - if not line.endswith(" \\"): - raise Exception(f"Missing ' \\' after '{line}' in Windows test-subset.") - lines = [line.rstrip(" \\") for line in lines] # Unpack lines with >1 item testfiles = [] @@ -97,7 +92,7 @@ def from_yaml(): included = all_tests - ignored for testfile in included: - df.loc[testfile, (floatX, os_)] += 1 + df.loc[testfile, (linker, os_)] += 1 ignored_by_all = set(df[df.eq(0).all(axis=1)].index) run_multiple_times = set(df[df.gt(1).any(axis=1)].index) @@ -111,7 +106,7 @@ def from_yaml(): ) if run_multiple_times: raise AssertionError( - f"{len(run_multiple_times)} tests are run multiple times with the same OS and floatX setting:\n{run_multiple_times}" + f"{len(run_multiple_times)} tests are run multiple times with the same OS and pytensor flags:\n{run_multiple_times}" ) return From a471e75049ec77a18c17e650a269edad3378c2b7 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 24 Dec 2025 00:18:30 +0100 Subject: [PATCH 3/3] Temporary workaround for Dirichlet --- pymc/distributions/multivariate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index c167080b10..60c7cdd70b 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -544,7 +544,8 @@ class Dirichlet(SimplexContinuous): @classmethod def dist(cls, a, **kwargs): - a = pt.as_tensor_variable(a) + # dtype="floatX" is a temporary work-around for numba impl of dirichlet with discrete alpha + a = pt.as_tensor_variable(a, dtype="floatX") # mean = a / pt.sum(a) # mode = pt.switch(pt.all(a > 1), (a - 1) / pt.sum(a - 1), np.nan)
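[Editor's note — not part of the patches: PATCH 3/3 casts the Dirichlet concentration to floatX so that integer-valued `a` never reaches the numba backend's dirichlet sampler, which does not yet support discrete alpha. Below is a minimal sketch of the effect, assuming the current public PyMC API and that the concentration is the last input of the RV node:

import numpy as np
import pymc as pm

# An integer concentration vector previously kept an integer dtype;
# after the patch it is cast to floatX before the RV is built.
rv = pm.Dirichlet.dist(a=np.array([1, 1, 1]))
assert rv.owner.inputs[-1].dtype.startswith("float")

# Draws still behave like Dirichlet samples: non-negative and summing to one.
draws = pm.draw(rv, random_seed=0)
assert (draws >= 0).all()
np.testing.assert_allclose(draws.sum(), 1.0, rtol=1e-6)
]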