From afa7bbfebae92a2919a1d99ba8a9f49701fe67de Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 21 Sep 2022 12:00:19 +0200
Subject: [PATCH 1/2] Remove unused caplog

---
 pymc/tests/test_sampling.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index a004c486d2..ae200b29ad 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -673,7 +673,7 @@ def test_normal_scalar_idata(self):
         ppc = pm.sample_posterior_predictive(idata, keep_size=True, return_inferencedata=False)
         assert ppc["a"].shape == (nchains, ndraws)
 
-    def test_normal_vector(self, caplog):
+    def test_normal_vector(self):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0)
             a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))
@@ -710,7 +710,7 @@ def test_normal_vector(self, caplog):
         assert "a" in ppc
         assert ppc["a"].shape == (12, 2)
 
-    def test_normal_vector_idata(self, caplog):
+    def test_normal_vector_idata(self):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0)
             a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))
@@ -726,7 +726,7 @@ def test_normal_vector_idata(self, caplog):
         ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False, keep_size=True)
         assert ppc["a"].shape == (trace.nchains, len(trace), 2)
 
-    def test_exceptions(self, caplog):
+    def test_exceptions(self):
         with pm.Model() as model:
             mu = pm.Normal("mu", 0.0, 1.0)
             a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))

From 6479ad449ecee556abed775d00ee31cdd550109c Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 21 Sep 2022 11:59:05 +0200
Subject: [PATCH 2/2] Log sampled basic_RVs in sample_*_predictive functions

---
 pymc/sampling.py            |  48 ++++++-----
 pymc/tests/test_sampling.py | 154 ++++++++++++++++++++++++++++++++----
 2 files changed, 165 insertions(+), 37 deletions(-)

diff --git a/pymc/sampling.py b/pymc/sampling.py
index 20100de25e..93055fd725 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -1622,7 +1622,7 @@ def compile_forward_sampling_function(
     basic_rvs: Optional[List[Variable]] = None,
     givens_dict: Optional[Dict[Variable, Any]] = None,
     **kwargs,
-) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
+) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
     """Compile a function to draw samples, conditioned on the values of some variables.
 
     The goal of this function is to walk the aesara computational graph from the list
@@ -1635,13 +1635,10 @@ def compile_forward_sampling_function(
     - Variables in the outputs list
     - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or
       ``RandomGeneratorSharedVariable``
-    - Basic RVs that are not in the ``vars_in_trace`` list
+    - Variables that are in the ``basic_rvs`` list but not in the ``vars_in_trace`` list
    - Variables that are keys in the ``givens_dict``
     - Variables that have volatile inputs
 
-    Where by basic RVs we mean ``Variable`` instances produced by a ``RandomVariable`` ``Op``
-    that are in the ``basic_rvs`` list.
-
     Concretely, this function can be used to compile a function to sample from the
     posterior predictive distribution of a model that has variables that are conditioned
     on ``MutableData`` instances. The variables that depend on the mutable data will be
@@ -1670,12 +1667,19 @@ def compile_forward_sampling_function(
         output of ``model.basic_RVs``) should have a reference to the variables that
         should be considered as random variable instances.
         This includes variables that have a ``RandomVariable`` owner op, but also impure random variables like Mixtures, or
-        Censored distributions. If ``None``, only pure random variables will be considered
-        as potential random variables.
+        Censored distributions.
     givens_dict : Optional[Dict[aesara.graph.basic.Variable, Any]]
         A dictionary that maps tensor variables to the values that should be used to
         replace them in the compiled function. The types of the key and value should
         match or an error will be raised during compilation.
+
+    Returns
+    -------
+    function: Callable
+        Compiled forward sampling Aesara function
+    volatile_basic_rvs: Set of Variable
+        Set of all basic_rvs that were considered volatile and will be resampled when
+        the function is evaluated.
     """
     if givens_dict is None:
         givens_dict = {}
@@ -1741,7 +1745,10 @@ def expand(node):
         for node, value in givens_dict.items()
     ]
 
-    return compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs)
+    return (
+        compile_pymc(inputs, fg.outputs, givens=givens, on_unused_input="ignore", **kwargs),
+        set(basic_rvs) & (volatile_nodes - set(givens_dict)),  # Basic RVs that will be resampled
+    )
 
 
 def sample_posterior_predictive(
@@ -1900,7 +1907,6 @@ def sample_posterior_predictive(
         vars_ = model.observed_RVs + model.auto_deterministics
 
     indices = np.arange(samples)
-
     if progressbar:
         indices = progress_bar(indices, total=samples, display=progressbar)
 
@@ -1923,17 +1929,17 @@ def sample_posterior_predictive(
     compile_kwargs.setdefault("allow_input_downcast", True)
     compile_kwargs.setdefault("accept_inplace", True)
 
-    sampler_fn = point_wrapper(
-        compile_forward_sampling_function(
-            outputs=vars_to_sample,
-            vars_in_trace=vars_in_trace,
-            basic_rvs=model.basic_RVs,
-            givens_dict=None,
-            random_seed=random_seed,
-            **compile_kwargs,
-        )
+    _sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
+        outputs=vars_to_sample,
+        vars_in_trace=vars_in_trace,
+        basic_rvs=model.basic_RVs,
+        givens_dict=None,
+        random_seed=random_seed,
+        **compile_kwargs,
     )
-
+    sampler_fn = point_wrapper(_sampler_fn)
+    # All model variables have a name, but mypy does not know this
+    _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
     ppc_trace_t = _DefaultTrace(samples)
     try:
         if isinstance(_trace, MultiTrace):
@@ -2242,7 +2248,7 @@ def sample_prior_predictive(
     compile_kwargs.setdefault("allow_input_downcast", True)
     compile_kwargs.setdefault("accept_inplace", True)
 
-    sampler_fn = compile_forward_sampling_function(
+    sampler_fn, volatile_basic_rvs = compile_forward_sampling_function(
         vars_to_sample,
         vars_in_trace=[],
         basic_rvs=model.basic_RVs,
@@ -2251,6 +2257,8 @@ def sample_prior_predictive(
         **compile_kwargs,
     )
 
+    # All model variables have a name, but mypy does not know this
+    _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
     values = zip(*(sampler_fn() for i in range(samples)))
 
     data = {k: np.stack(v) for k, v in zip(names, values)}
 
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index ae200b29ad..b88663e8c7 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import re
 import unittest.mock as mock
 import warnings
@@ -1023,6 +1024,110 @@ def test_deterministics_out_of_idata(self, multitrace):
         ppc = pm.sample_posterior_predictive(trace, var_names="c", return_inferencedata=True)
         assert np.all(np.abs(ppc.posterior_predictive.c + 4) <= 0.1)
 
+    def test_logging_sampled_basic_rvs_prior(self, caplog):
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            y = pm.Deterministic("y", x + 1)
+            z = pm.Normal("z", y, observed=0)
+
+        with m:
+            pm.sample_prior_predictive(samples=1)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, z]")]
+        caplog.clear()
+
+        with m:
+            pm.sample_prior_predictive(samples=1, var_names=["x"])
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x]")]
+        caplog.clear()
+
+    def test_logging_sampled_basic_rvs_posterior(self, caplog):
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            x_det = pm.Deterministic("x_det", x + 1)
+            y = pm.Normal("y", x_det)
+            z = pm.Normal("z", y, observed=0)
+
+        idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)})
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")]
+        caplog.clear()
+
+        with m:
+            pm.sample_posterior_predictive(idata, var_names=["y", "z"])
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")]
+        caplog.clear()
+
+        # Resampling `x` will force resampling of `y`, even if it is in the trace
+        with m:
+            pm.sample_posterior_predictive(idata, var_names=["x", "z"])
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")]
+        caplog.clear()
+
+        # Missing deterministic `x_det` does not show up in the log, even if it is being
+        # recomputed; only the `y` RV shows
+        idata = az_from_dict(posterior={"x": np.zeros(5)})
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")]
+        caplog.clear()
+
+        # Missing deterministic `x_det` does not cause recomputation of the downstream `y` RV
+        idata = az_from_dict(posterior={"x": np.zeros(5), "y": np.ones(5)})
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [z]")]
+        caplog.clear()
+
+        # Missing `x` causes sampling of the downstream `y` RV, even if it is present in the trace
+        idata = az_from_dict(posterior={"y": np.ones(5)})
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")]
+        caplog.clear()
+
+    def test_logging_sampled_basic_rvs_posterior_mutable(self, caplog):
+        with pm.Model() as m:
+            x1 = pm.MutableData("x1", 0)
+            y1 = pm.Normal("y1", x1)
+
+            x2 = pm.ConstantData("x2", 0)
+            y2 = pm.Normal("y2", x2)
+
+            z = pm.Normal("z", y1 + y2, observed=0)
+
+        # `y1` will be resampled because it depends on a `MutableData`, whereas `y2` won't.
+        # This behavior might change in the future, as it is undesirable when the
+        # `MutableData` hasn't changed since sampling
+        idata = az_from_dict(posterior={"y1": np.zeros(5), "y2": np.zeros(5)})
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y1, z]")]
+        caplog.clear()
+
+        # `y1` should now be resampled regardless of whether it was in the trace or not,
+        # as the posterior is no longer relevant!
+        x1.set_value(1)
+        with m:
+            pm.sample_posterior_predictive(idata)
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y1, z]")]
+        caplog.clear()
+
+    def test_logging_sampled_basic_rvs_posterior_deterministic(self, caplog):
+        with pm.Model() as m:
+            x = pm.Normal("x")
+            x_det = pm.Deterministic("x_det", x + 1)
+            y = pm.Normal("y", x_det)
+            z = pm.Normal("z", y, observed=0)
+
+        # Explicitly resampling a deterministic will lead to resampling of the downstream RV `y`
+        # This behavior could change in the future, as the posterior of `y` is still valid
+        idata = az_from_dict(posterior={"x": np.zeros(5), "x_det": np.ones(5), "y": np.ones(5)})
+        with m:
+            pm.sample_posterior_predictive(idata, var_names=["x_det", "z"])
+        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")]
+        caplog.clear()
+
 
 @pytest.mark.xfail(
     reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError
@@ -1621,11 +1726,12 @@ def test_linear_model(self):
             sigma = pm.HalfNormal("sigma", 0.1)
             obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape)
 
-        f = compile_forward_sampling_function(
+        f, volatile_rvs = compile_forward_sampling_function(
             [obs],
             vars_in_trace=[alpha, beta, sigma, mu],
             basic_rvs=model.basic_RVs,
         )
+        assert volatile_rvs == {obs}
         assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma"}
         assert {i.name for i in self.get_function_roots(f)} == {"x", "alpha", "beta", "sigma"}
 
@@ -1639,11 +1745,12 @@ def test_linear_model(self):
             sigma = pm.HalfNormal("sigma", 0.1)
             obs = pm.Normal("obs", mu, sigma, observed=y, shape=x.shape)
 
-        f = compile_forward_sampling_function(
+        f, volatile_rvs = compile_forward_sampling_function(
             [obs],
             vars_in_trace=[alpha, beta, sigma, mu],
             basic_rvs=model.basic_RVs,
         )
+        assert volatile_rvs == {obs}
         assert {i.name for i in self.get_function_inputs(f)} == {"alpha", "beta", "sigma", "mu"}
         assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"}
 
@@ -1657,22 +1764,24 @@ def test_nested_observed_model(self):
             beta = pm.Normal("beta", 0, 0.1, size=p.shape)
             mu = pm.Deterministic("mu", beta[category])
             sigma = pm.HalfNormal("sigma", 0.1)
-            pm.Normal("obs", mu, sigma, observed=y, shape=mu.shape)
+            obs = pm.Normal("obs", mu, sigma, observed=y, shape=mu.shape)
 
-        f = compile_forward_sampling_function(
+        f, volatile_rvs = compile_forward_sampling_function(
             outputs=model.observed_RVs,
             vars_in_trace=[beta, mu, sigma],
             basic_rvs=model.basic_RVs,
         )
+        assert volatile_rvs == {category, obs}
         assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"}
         assert {i.name for i in self.get_function_roots(f)} == {"x", "p", "beta", "sigma"}
 
-        f = compile_forward_sampling_function(
+        f, volatile_rvs = compile_forward_sampling_function(
             outputs=model.observed_RVs,
             vars_in_trace=[beta, mu, sigma],
             basic_rvs=model.basic_RVs,
             givens_dict={category: np.zeros(10, dtype=category.dtype)},
         )
+        assert volatile_rvs == {obs}
         assert {i.name for i in self.get_function_inputs(f)} == {"beta", "sigma"}
         assert {i.name for i in self.get_function_roots(f)} == {
             "x",
@@ -1688,17 +1797,18 @@ def test_volatile_parameters(self):
             mu = pm.Normal("mu", 0, 1)
             nested_mu = pm.Normal("nested_mu", mu, 1, size=10)
             sigma = pm.HalfNormal("sigma", 1)
-            pm.Normal("obs", nested_mu, sigma, observed=y, shape=nested_mu.shape)
+            obs = pm.Normal("obs", nested_mu, sigma, observed=y, shape=nested_mu.shape)
 
-        f = compile_forward_sampling_function(
+        f, volatile_rvs = compile_forward_sampling_function(
             outputs=model.observed_RVs,
             vars_in_trace=[nested_mu,
sigma], # mu isn't in the trace and will be deemed volatile basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {mu, nested_mu, obs} assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} assert {i.name for i in self.get_function_roots(f)} == {"sigma"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=model.observed_RVs, vars_in_trace=[mu, nested_mu, sigma], basic_rvs=model.basic_RVs, @@ -1706,6 +1816,7 @@ def test_volatile_parameters(self): mu: np.array(1.0) }, # mu will be considered volatile because it's in givens ) + assert volatile_rvs == {nested_mu, obs} assert {i.name for i in self.get_function_inputs(f)} == {"sigma"} assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma"} @@ -1719,27 +1830,30 @@ def test_mixture(self): mix_mu = pm.Mixture("mix_mu", w=w, comp_dists=components) obs = pm.Normal("obs", mix_mu, 1, observed=np.ones((5, 3))) - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[mix_mu, mu, w], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {obs} assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu", "mix_mu"} assert {i.name for i in self.get_function_roots(f)} == {"mix_mu"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[mu, w], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {mix_mu, obs} assert {i.name for i in self.get_function_inputs(f)} == {"w", "mu"} assert {i.name for i in self.get_function_roots(f)} == {"w", "mu"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[mix_mu, mu], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {w, mix_mu, obs} assert {i.name for i in self.get_function_inputs(f)} == {"mu"} assert {i.name for i in self.get_function_roots(f)} == {"mu"} @@ -1749,19 +1863,21 @@ def test_censored(self): mu = pm.Censored("mu", pm.Normal.dist(mu=latent_mu, sigma=1), lower=-1, upper=1) obs = pm.Normal("obs", mu, 1, observed=np.ones((10, 3))) - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[latent_mu, mu], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {obs} assert {i.name for i in self.get_function_inputs(f)} == {"latent_mu", "mu"} assert {i.name for i in self.get_function_roots(f)} == {"mu"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[mu], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {latent_mu, mu, obs} assert {i.name for i in self.get_function_inputs(f)} == set() assert {i.name for i in self.get_function_roots(f)} == set() @@ -1776,27 +1892,30 @@ def test_lkj_cholesky_cov(self): chol = pm.Deterministic("chol", chol) obs = pm.MvNormal("obs", mu=mu, chol=chol, observed=np.zeros(3)) - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[chol_packed, chol], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {obs} assert {i.name for i in self.get_function_inputs(f)} == {"chol_packed", "chol"} assert {i.name for i in self.get_function_roots(f)} == {"chol"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[chol_packed], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {obs} assert {i.name for i in 
self.get_function_inputs(f)} == {"chol_packed"} assert {i.name for i in self.get_function_roots(f)} == {"chol_packed"} - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[chol], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {chol_packed, obs} assert {i.name for i in self.get_function_inputs(f)} == set() assert {i.name for i in self.get_function_roots(f)} == set() @@ -1811,11 +1930,12 @@ def test_non_random_model_variable(self): obs = pm.Normal("obs", y_abs, observed=np.zeros(10)) # y_abs should be resampled even if in the trace, because the source y is missing - f = compile_forward_sampling_function( + f, volatile_rvs = compile_forward_sampling_function( outputs=[obs], vars_in_trace=[y_abs], basic_rvs=model.basic_RVs, ) + assert volatile_rvs == {y, obs} assert {i.name for i in self.get_function_inputs(f)} == set() assert {i.name for i in self.get_function_roots(f)} == set()
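
The tests above pin down the user-facing effect of this patch series: both `pm.sample_prior_predictive` and `pm.sample_posterior_predictive` now emit a single INFO record on the "pymc" logger naming exactly the basic RVs that were deemed volatile and will be drawn. A minimal sketch of how this surfaces in an interactive session follows; the two-variable model is illustrative only and is not taken from the patch.

import logging

import numpy as np
import pymc as pm

# Make INFO records from the "pymc" logger visible on stderr.
logging.basicConfig(level=logging.INFO)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    obs = pm.Normal("obs", mu, 1.0, observed=np.zeros(3))
    idata = pm.sample(draws=100, chains=2)

with model:
    # `mu` has draws stored in the trace, so it is reused; `obs` is never
    # stored in the trace, so it is volatile and gets resampled.
    # Expected log record: "Sampling: [obs]"
    pm.sample_posterior_predictive(idata)

with model:
    # Naming `mu` in var_names forces it to be resampled, which in turn
    # makes the downstream `obs` volatile as well.
    # Expected log record: "Sampling: [mu, obs]"
    pm.sample_posterior_predictive(idata, var_names=["mu", "obs"])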