diff --git a/pymc/model.py b/pymc/model.py
index d8574ffcd8..a280e4c6fe 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -1149,7 +1149,7 @@ def add_coord(
             length = len(values)
         if not isinstance(length, Variable):
             if mutable:
-                length = aesara.shared(length)
+                length = aesara.shared(length, name=name)
             else:
                 length = aesara.tensor.constant(length)
         self._dim_lengths[name] = length
diff --git a/pymc/sampling.py b/pymc/sampling.py
index 93055fd725..a3f4dcb0fb 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -1621,6 +1621,8 @@ def compile_forward_sampling_function(
     vars_in_trace: List[Variable],
     basic_rvs: Optional[List[Variable]] = None,
     givens_dict: Optional[Dict[Variable, Any]] = None,
+    constant_data: Optional[Dict[str, np.ndarray]] = None,
+    constant_coords: Optional[Set[str]] = None,
     **kwargs,
 ) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
     """Compile a function to draw samples, conditioned on the values of some variables.
@@ -1634,18 +1636,18 @@
     compiled function or after inference has been run. These variables are:

     - Variables in the outputs list
-    - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
+    - ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values have changed with respect to what they were at inference time
     - Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
     - Variables that are keys in the ``givens_dict``
     - Variables that have volatile inputs

     Concretely, this function can be used to compile a function to sample from the
     posterior predictive distribution of a model that has variables that are conditioned
-    on ``MutableData`` instances. The variables that depend on the mutable data will be
-    considered volatile, and as such, they wont be included as inputs into the compiled function.
-    This means that if they have values stored in the posterior, these values will be ignored
-    and new values will be computed (in the case of deterministics and potentials) or sampled
-    (in the case of random variables).
+    on ``MutableData`` instances. The variables that depend on mutable data that has changed
+    will be considered volatile, and as such, they won't be included as inputs into the compiled
+    function. This means that if they have values stored in the posterior, these values will be
+    ignored and new values will be computed (in the case of deterministics and potentials) or
+    sampled (in the case of random variables).

     This function also enables a way to impute values for any variable in the computational
     graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
@@ -1672,6 +1674,25 @@
         A dictionary that maps tensor variables to the values that should be used to replace them
         in the compiled function. The types of the key and value should match or an error will be
         raised during compilation.
+    constant_data : Optional[Dict[str, numpy.ndarray]]
+        A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their
+        corresponding values at inference time. If a model was created with ``MutableData``, these
+        are stored as ``SharedVariable`` instances with the name of the data variable and a value
+        equal to the initial data. At inference time, this information is stored in
+        ``InferenceData`` objects under the ``constant_data`` group, which allows us to check
+        whether a ``SharedVariable`` instance changed its value after inference. If the value has
+        changed, the ``SharedVariable`` is assumed to be volatile; if it has not changed, the
+        ``SharedVariable`` is assumed to not be volatile. If a ``SharedVariable`` is not found
+        in either ``constant_data`` or ``constant_coords``, it is assumed to be volatile.
+        Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary.
+    constant_coords : Optional[Set[str]]
+        A set with the names of the mutable coordinates that have not changed their shape after
+        inference. If a model was created with mutable coordinates, these are stored as
+        ``SharedVariable`` instances with the name of the coordinate and a value equal to the
+        length of said coordinate. This set lets us check whether a ``SharedVariable``
+        corresponds to a mutated coordinate (i.e. its name is not in the set), in which case it
+        is considered volatile. If a ``SharedVariable`` is not found in either ``constant_data``
+        or ``constant_coords``, it is assumed to be volatile. Setting ``constant_coords`` to
+        ``None`` is equivalent to passing an empty set.

     Returns
     -------
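To make the two new keyword arguments concrete: below is a minimal sketch (not part of the diff) of how ``constant_data`` can be assembled from an ``InferenceData`` object, mirroring what ``sample_posterior_predictive`` does further down in this change; the helper name ``extract_constant_data`` is hypothetical.

def extract_constant_data(idata):
    # Map MutableData/ConstantData names to the values recorded at inference time.
    # InferenceData stores these under the ``constant_data`` group, when present.
    constant_data = {}
    group = getattr(idata, "constant_data", None)
    if group is not None:
        # Same extraction sample_posterior_predictive performs internally
        constant_data.update({str(k): v.data for k, v in group.items()})
    return constant_data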
@@ -1687,6 +1708,20 @@
     if basic_rvs is None:
         basic_rvs = []

+    if constant_data is None:
+        constant_data = {}
+    if constant_coords is None:
+        constant_coords = set()
+
+    # Helper to check whether a shared variable's current value still matches an array
+    def shared_value_matches(var):
+        try:
+            old_array_value = constant_data[var.name]
+        except KeyError:
+            return var.name in constant_coords
+        current_shared_value = var.get_value(borrow=True)
+        return np.array_equal(old_array_value, current_shared_value)
+
     # We need a function graph to walk the clients and propagate the volatile property
     fg = FunctionGraph(outputs=outputs, clone=False)

@@ -1702,6 +1737,7 @@
         or (  # SharedVariables, except RandomState/Generators
             isinstance(node, SharedVariable)
             and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
+            and not shared_value_matches(node)
         )
         or (  # Basic RVs that are not in the trace
             node in basic_rvs and node not in vars_in_trace
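The decision rule that ``shared_value_matches`` implements can be shown in isolation. The following standalone snippet (not part of the diff; ``value_matches`` is a hypothetical stand-in that takes the shared variable's name and current value directly) reproduces its logic on plain arrays:

import numpy as np

def value_matches(name, current_value, constant_data, constant_coords):
    # A shared variable is non-volatile only if its recorded value is unchanged,
    # or if its name is listed as a coordinate whose length did not change
    if name in constant_data:
        return np.array_equal(constant_data[name], current_value)
    return name in constant_coords

assert value_matches("offsets", np.zeros(3), {"offsets": np.zeros(3)}, set())
assert not value_matches("offsets", np.ones(3), {"offsets": np.zeros(3)}, set())
assert value_matches("obs", 10, {}, {"obs"})      # unchanged mutable dim length
assert not value_matches("other", 10, {}, set())  # unknown shared -> volatile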
@@ -1835,16 +1871,24 @@ def sample_posterior_predictive(
         idata_kwargs = {}
     else:
         idata_kwargs = idata_kwargs.copy()
+    constant_data: Dict[str, np.ndarray] = {}
+    trace_coords: Dict[str, np.ndarray] = {}
     if "coords" not in idata_kwargs:
         idata_kwargs["coords"] = {}
     if isinstance(trace, InferenceData):
         idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
         idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
+        _constant_data = getattr(trace, "constant_data", None)
+        if _constant_data is not None:
+            trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
+            constant_data.update({str(k): v.data for k, v in _constant_data.items()})
+        trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
         _trace = dataset_to_point_list(trace["posterior"])
         nchain, len_trace = chains_and_samples(trace)
     elif isinstance(trace, xarray.Dataset):
         idata_kwargs["coords"].setdefault("draw", trace["draw"])
         idata_kwargs["coords"].setdefault("chain", trace["chain"])
+        trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
         _trace = dataset_to_point_list(trace)
         nchain, len_trace = chains_and_samples(trace)
     elif isinstance(trace, MultiTrace):
@@ -1901,6 +1945,16 @@ def sample_posterior_predictive(
             stacklevel=2,
         )

+    constant_coords = set()
+    for dim, coord in trace_coords.items():
+        current_coord = model.coords.get(dim, None)
+        if (
+            current_coord is not None
+            and len(coord) == len(current_coord)
+            and np.all(coord == current_coord)
+        ):
+            constant_coords.add(dim)
+
     if var_names is not None:
         vars_ = [model[x] for x in var_names]
     else:
@@ -1935,6 +1989,8 @@ def sample_posterior_predictive(
         basic_rvs=model.basic_RVs,
         givens_dict=None,
         random_seed=random_seed,
+        constant_data=constant_data,
+        constant_coords=constant_coords,
         **compile_kwargs,
     )
     sampler_fn = point_wrapper(_sampler_fn)
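The ``constant_coords`` set assembled above only admits dimensions whose coordinate values kept both their length and their contents between inference and the current model. Here is a self-contained sketch of that comparison (not part of the diff; the coordinate values are made up for illustration):

import numpy as np

trace_coords = {"name": np.array(["A", "B", "C"]), "obs": np.arange(10, 20)}
model_coords = {"name": ("A", "B", "D"), "obs": tuple(range(10, 20))}

constant_coords = set()
for dim, coord in trace_coords.items():
    current_coord = model_coords.get(dim, None)
    if (
        current_coord is not None
        and len(coord) == len(current_coord)
        and np.all(coord == current_coord)
    ):
        constant_coords.add(dim)

assert constant_coords == {"obs"}  # "name" changed one label, so it is not constant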
diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py
index 2be406b1a0..0a5239736c 100644
--- a/pymc/tests/test_sampling.py
+++ b/pymc/tests/test_sampling.py
@@ -1086,33 +1086,6 @@ def test_logging_sampled_basic_rvs_posterior(self, caplog):
         assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [x, y, z]")]
         caplog.clear()

-    def test_logging_sampled_basic_rvs_posterior_mutable(self, caplog):
-        with pm.Model() as m:
-            x1 = pm.MutableData("x1", 0)
-            y1 = pm.Normal("y1", x1)
-
-            x2 = pm.ConstantData("x2", 0)
-            y2 = pm.Normal("y2", x2)
-
-            z = pm.Normal("z", y1 + y2, observed=0)
-
-        # `y1` will be recomputed because it depends on a `MutableData` whereas `y2` won't
-        # This behavior might change in the future, as it is undesirable when `MutableData`
-        # hasn't changed since sampling
-        idata = az_from_dict(posterior={"y1": np.zeros(5), "y2": np.zeros(5)})
-        with m:
-            pm.sample_posterior_predictive(idata)
-        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y1, z]")]
-        caplog.clear()
-
-        # `y1` should now be resampled regardless of whether it was in the trace or not
-        # as the posterior is no longer revelant!
-        x1.set_value(1)
-        with m:
-            pm.sample_posterior_predictive(idata)
-        assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y1, z]")]
-        caplog.clear()
-
     def test_logging_sampled_basic_rvs_posterior_deterministic(self, caplog):
         with pm.Model() as m:
             x = pm.Normal("x")
@@ -1128,6 +1101,131 @@ def test_logging_sampled_basic_rvs_posterior_deterministic(self, caplog):
         assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y, z]")]
         caplog.clear()

+    @staticmethod
+    def make_mock_model():
+        rng = np.random.default_rng(seed=42)
+        data = rng.normal(loc=1, scale=0.2, size=(10, 3))
+        with pm.Model() as model:
+            model.add_coord("name", ["A", "B", "C"], mutable=True)
+            model.add_coord("obs", list(range(10, 20)), mutable=True)
+            offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,)))
+            a = pm.Normal("a", mu=0, sigma=1, dims=["name"])
+            b = pm.Normal("b", mu=offsets, sigma=1)
+            mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"])
+            sigma = pm.HalfNormal("sigma", sigma=1, dims=["name"])
+
+            data = pm.MutableData(
+                "y_obs",
+                data,
+                dims=["obs", "name"],
+            )
+            pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"])
+        return model
+
+    @pytest.fixture(scope="class")
+    def mock_multitrace(self):
+        with self.make_mock_model():
+            trace = pm.sample(
+                draws=10,
+                tune=10,
+                chains=2,
+                progressbar=False,
+                compute_convergence_checks=False,
+                return_inferencedata=False,
+                random_seed=42,
+            )
+        return trace
+
+    @pytest.fixture(scope="class", params=["MultiTrace", "InferenceData", "Dataset"])
+    def mock_sample_results(self, request, mock_multitrace):
+        kind = request.param
+        trace = mock_multitrace
+        # We rebuild the model to ensure that all dimensions, data and coords start out
+        # the same across parameter values
+        model = self.make_mock_model()
+        if kind == "MultiTrace":
+            return kind, trace, model
+        else:
+            idata = pm.to_inference_data(
+                trace,
+                save_warmup=False,
+                model=model,
+                log_likelihood=False,
+            )
+            if kind == "Dataset":
+                return kind, idata.posterior, model
+            else:
+                return kind, idata, model
+
+    def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results, caplog):
+        kind, samples, model = mock_sample_results
+        with model:
+            pm.sample_posterior_predictive(samples)
+        if kind == "MultiTrace":
+            # A MultiTrace only holds the actual MCMC posterior samples, with no information
+            # on the MutableData or mutable coordinate values, so it will always assume they
+            # are volatile and resample their descendants
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+        elif kind == "InferenceData":
+            # InferenceData has all MCMC posterior samples and the values for both coordinates and
+            # data containers. This enables it to see that no data has changed and it should only
+            # resample the observed variable
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [y]")]
+            caplog.clear()
+        elif kind == "Dataset":
+            # A Dataset has all MCMC posterior samples and the values of the coordinates. This
+            # enables it to see that the coordinates have not changed, but the MutableData is
+            # assumed volatile by default
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")]
+            caplog.clear()

+        original_offsets = model["offsets"].get_value()
+        with model:
+            # Changing the MutableData values. This will only be picked up by InferenceData
+            pm.set_data({"offsets": original_offsets + 1})
+            pm.sample_posterior_predictive(samples)
+        if kind == "MultiTrace":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+        elif kind == "InferenceData":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")]
+            caplog.clear()
+        elif kind == "Dataset":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [b, y]")]
+            caplog.clear()
+
+        with model:
+            # Changing the mutable coordinates. This will be picked up by InferenceData and Dataset
+            model.set_dim("name", new_length=4, coord_values=["D", "E", "F", "G"])
+            pm.set_data({"offsets": original_offsets, "y_obs": np.zeros((10, 4))})
+            pm.sample_posterior_predictive(samples)
+        if kind == "MultiTrace":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+        elif kind == "InferenceData":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, sigma, y]")]
+            caplog.clear()
+        elif kind == "Dataset":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+
+        with model:
+            # Changing the mutable coordinate values, but not their shape, while also changing
+            # the MutableData. This will trigger resampling of all variables
+            model.set_dim("name", new_length=3, coord_values=["A", "B", "D"])
+            pm.set_data({"offsets": original_offsets + 1, "y_obs": np.zeros((10, 3))})
+            pm.sample_posterior_predictive(samples)
+        if kind == "MultiTrace":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+        elif kind == "InferenceData":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+        elif kind == "Dataset":
+            assert caplog.record_tuples == [("pymc", logging.INFO, "Sampling: [a, b, sigma, y]")]
+            caplog.clear()
+
     @pytest.mark.xfail(
         reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError
@@ -1939,6 +2037,95 @@ def test_non_random_model_variable(self):
         assert {i.name for i in self.get_function_inputs(f)} == set()
         assert {i.name for i in self.get_function_roots(f)} == set()

+    def test_mutable_coords_volatile(self):
+        rng = np.random.default_rng(seed=42)
+        data = rng.normal(loc=1, scale=0.2, size=(10, 3))
+        with pm.Model() as model:
+            model.add_coord("name", ["A", "B", "C"], mutable=True)
+            model.add_coord("obs", list(range(10, 20)), mutable=True)
+            offsets = pm.MutableData("offsets", rng.normal(0, 1, size=(10,)))
+            a = pm.Normal("a", mu=0, sigma=1, dims=["name"])
+            b = pm.Normal("b", mu=offsets, sigma=1)
+            mu = pm.Deterministic("mu", a + b[..., None], dims=["obs", "name"])
+            sigma = pm.HalfNormal("sigma", sigma=1, dims=["name"])
+
+            data = pm.MutableData(
+                "y_obs",
+                data,
+                dims=["obs", "name"],
+            )
+            y = pm.Normal("y", mu=mu, sigma=sigma, observed=data, dims=["obs", "name"])
+
+        # When no constant_data and constant_coords are given, all the dependent nodes will be
+        # volatile and resampled
+        f, volatile_rvs = compile_forward_sampling_function(
+            outputs=[y],
+            vars_in_trace=[a, b, mu, sigma],
+            basic_rvs=model.basic_RVs,
+        )
+        assert volatile_rvs == {y, a, b, sigma}
+        assert {i.name for i in self.get_function_inputs(f)} == set()
+        assert {i.name for i in self.get_function_roots(f)} == {"name", "obs", "offsets"}
+
+        # When the constant data has the same values as the shared data, offsets won't be volatile
+        f, volatile_rvs = compile_forward_sampling_function(
+            outputs=[y],
+            vars_in_trace=[a, b, mu, sigma],
+            basic_rvs=model.basic_RVs,
+            constant_data={"offsets": offsets.get_value()},
+        )
+        assert volatile_rvs == {y, a, sigma}
+        assert {i.name for i in self.get_function_inputs(f)} == {"b"}
+        assert {i.name for i in self.get_function_roots(f)} == {"b", "name", "obs"}
+
+        # When we declare constant_coords, the shared variables with matching names won't be volatile
+        f, volatile_rvs = compile_forward_sampling_function(
+            outputs=[y],
+            vars_in_trace=[a, b, mu, sigma],
+            basic_rvs=model.basic_RVs,
+            constant_coords={"name", "obs"},
+        )
+        assert volatile_rvs == {y, b}
+        assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"}
+        assert {i.name for i in self.get_function_roots(f)} == {
+            "a",
+            "sigma",
+            "name",
+            "obs",
+            "offsets",
+        }
+
+        # When we have both constant_data and constant_coords, only y will be volatile
+        f, volatile_rvs = compile_forward_sampling_function(
+            outputs=[y],
+            vars_in_trace=[a, b, mu, sigma],
+            basic_rvs=model.basic_RVs,
+            constant_data={"offsets": offsets.get_value()},
+            constant_coords={"name", "obs"},
+        )
+        assert volatile_rvs == {y}
+        assert {i.name for i in self.get_function_inputs(f)} == {"a", "b", "mu", "sigma"}
+        assert {i.name for i in self.get_function_roots(f)} == {"mu", "sigma", "name", "obs"}
+
+        # When constant_data has different values than the shared variable, then
+        # offsets will be volatile
+        f, volatile_rvs = compile_forward_sampling_function(
+            outputs=[y],
+            vars_in_trace=[a, b, mu, sigma],
+            basic_rvs=model.basic_RVs,
+            constant_data={"offsets": offsets.get_value() + 1},
+            constant_coords={"name", "obs"},
+        )
+        assert volatile_rvs == {y, b}
+        assert {i.name for i in self.get_function_inputs(f)} == {"a", "sigma"}
+        assert {i.name for i in self.get_function_roots(f)} == {
+            "a",
+            "sigma",
+            "name",
+            "obs",
+            "offsets",
+        }
+

 def test_get_seeds_per_chain():
     ret = _get_seeds_per_chain(None, chains=1)
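Taken together, the user-facing effect is that forward sampling only treats data-dependent variables as volatile when the data actually changed. Below is an end-to-end sketch of that behavior (not part of the diff; modeled on the rewritten logging test, with tiny sampling settings for speed):

import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.MutableData("x", 0.0)
    y = pm.Normal("y", x)
    pm.Normal("z", y, observed=0.0)
    idata = pm.sample(draws=10, tune=10, chains=1, progressbar=False)

with m:
    # Data unchanged since inference: `y` is taken from the posterior and only
    # the observed variable `z` is resampled
    pm.sample_posterior_predictive(idata)

    # Data changed: `y` becomes volatile and is resampled as well
    pm.set_data({"x": np.array(1.0)})
    pm.sample_posterior_predictive(idata)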