From 3414f80411b79c68ef45381aaaa2d5a4784ec0f0 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Sun, 30 Nov 2025 16:19:03 +0500 Subject: [PATCH 1/5] Fixed UserWarning when converting sample_stats to idata --- pymc/backends/arviz.py | 37 +++++++++++++++++++++++-------------- pymc/smc/sampling.py | 4 +++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..4b4ede3935 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -29,7 +29,8 @@ import xarray from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires +from arviz.data.base import CoordSpec, DimSpec, requires +from arviz_base import dict_to_dataset from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -305,14 +306,14 @@ def posterior_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, @@ -347,14 +348,14 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, @@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self): data = self.posterior_predictive dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) @requires(["predictions"]) @@ -376,7 +381,11 @@ def predictions_to_xarray(self): data = self.predictions dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) def priors_to_xarray(self): @@ -399,7 +408,7 @@ def priors_to_xarray(self): if var_names is None else dict_to_dataset_drop_incompatible_coords( {k: np.expand_dims(self.prior[k], 0) for k in var_names}, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, ) @@ -414,10 +423,10 @@ def observed_data_to_xarray(self): return None return dict_to_dataset( self.observations, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) @requires("model") @@ -429,10 +438,10 @@ def constant_data_to_xarray(self): xarray_dataset = dict_to_dataset( constant_data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) # provisional handling of scalars in constant @@ -707,9 +716,9 @@ def apply_function_over_dataset( return dict_to_dataset( out_trace, - library=pymc, + inference_library=pymc, dims=dims, coords=coords, - default_dims=list(sample_dims), + sample_dims=list(sample_dims), skip_event_dims=True, ) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 249f6c5253..08d383db15 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -267,7 +267,9 @@ def _save_sample_stats( sample_stats = dict_to_dataset( sample_stats_dict, attrs=sample_settings_dict, - library=pymc, + inference_library=pymc, + sample_dims=["chain"], + check_conventions=False, ) ikwargs: dict[str, Any] = {"model": model} From fb7afafacf7ab4774d077f6f5336fdcba67c5260 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Sun, 30 Nov 2025 17:37:35 +0500 Subject: [PATCH 2/5] Added arviz-base to requirements list --- requirements-dev.txt | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0c8818d531..871a86d61e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ # See that file for comments about the need/usage of each dependency. arviz>=0.13.0 +arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle ipython>=7.16 diff --git a/requirements.txt b/requirements.txt index 7063fe5a81..1c82e66b93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ arviz>=0.13.0 +arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle numpy>=1.25.0 From 0ea080684e3e523f4138ad808c5d2aa71416e7e6 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Mon, 1 Dec 2025 02:19:18 +0500 Subject: [PATCH 3/5] Added arviz-base dependency in conda-envs/environment-dev.yml --- conda-envs/environment-dev.yml | 1 + requirements-dev.txt | 2 +- requirements.txt | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 1d46fa91cd..9064efbdcb 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - blas - cachetools>=4.2.1 - cloudpickle diff --git a/requirements-dev.txt b/requirements-dev.txt index 871a86d61e..6fca613919 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. +arviz-base arviz>=0.13.0 -arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle ipython>=7.16 diff --git a/requirements.txt b/requirements.txt index 1c82e66b93..7063fe5a81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ arviz>=0.13.0 -arviz-base>=0.7.0 cachetools>=4.2.1 cloudpickle numpy>=1.25.0 From ff01f4862032ca7d35cf804ae437b9ca748abed2 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Mon, 1 Dec 2025 02:55:50 +0500 Subject: [PATCH 4/5] Attempting to add arviz-base as a dependency --- conda-envs/environment-docs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 3d0fbcf819..2b21af774a 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 From 1dfea292877345b0f5b6ce5b6e010cfb168acb5c Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Tue, 23 Dec 2025 00:03:27 +0500 Subject: [PATCH 5/5] Transferred imports of rcParams and requires from arviz to arviz_base --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-test.yml | 1 + pymc/backends/arviz.py | 8 +++++--- requirements-dev.txt | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 9064efbdcb..84bab0d7ea 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 -- arviz-base +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 2b21af774a..e7284f77b1 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 -- arviz-base +- arviz-base>=0.7.0 - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index c47b53946b..c0eb0e7b8c 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz_base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 4b4ede3935..94b5bf69aa 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -28,9 +28,11 @@ import numpy as np import xarray -from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, requires +from arviz import InferenceData, concat +from arviz.data.base import CoordSpec, DimSpec from arviz_base import dict_to_dataset +from arviz_base.base import requires +from arviz_base.rcparams import RcParams from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -212,7 +214,7 @@ def __init__( save_warmup: bool | None = None, include_transformed: bool = False, ): - self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.save_warmup = RcParams["data.save_warmup"] if save_warmup is None else save_warmup self.include_transformed = include_transformed self.trace = trace diff --git a/requirements-dev.txt b/requirements-dev.txt index 6fca613919..007b692c6c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. -arviz-base +arviz-base>=0.7.0 arviz>=0.13.0 cachetools>=4.2.1 cloudpickle