Add recipes.forward_filter_backward_rsample()#549
Conversation
|
All of the relevant downstream tests in Numpyro pass under these changes (on my local machine, at least), but the changes to the behavior of Since we don't pin Pyro |
eb8680
left a comment
There was a problem hiding this comment.
Looks good per Zoom reviews, just a few nits.
| assert isinstance(reduced_vars, frozenset), reduced_vars | ||
|
|
||
| # Attempt to convert ReductionOp to AssociativeOp. | ||
| if isinstance(op, ops.ReductionOp): |
There was a problem hiding this comment.
Do we want this standardization logic to run under all interpretations?
There was a problem hiding this comment.
yes, I see .reduce() as lightweight syntax around deeper reinterpretable syntax including Reduce. Note it is nonsensical to create a lazy Reduce of a ReductionOp, the way we have defined Reduce.
| actual = q1 + s * q2 | ||
| assert_close(actual, expected) | ||
|
|
||
| if approximate not in (monte_carlo, monte_carlo_10): |
There was a problem hiding this comment.
Nit: this will report the test as passing under monte_carlo, would it be better to keep the xfail status?
There was a problem hiding this comment.
I've just removed these test as I don't understand what they do.
| actual = q1 + s * q2 | ||
| assert_close(actual, expected) | ||
|
|
||
| if approximate not in (monte_carlo, monte_carlo_10): |
There was a problem hiding this comment.
ditto: should this xfail rather than pass?
There was a problem hiding this comment.
oh right, I was going to strengthen these tests once we fixed adjoint, will do...
| def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob): | ||
| """ | ||
| This can be seen as performing naive tensor variable elimination by | ||
| breaking all plates and creating a single flat joint distribution. |
There was a problem hiding this comment.
If this actually turns out to be true, at some point we should move this logic into funsor.sum_product and test it against existing implementations since it's somewhat complex for testing logic and it is used to check correctness of what will presumably end up as one of our most important code paths if/when AutoGaussian is finished.
There was a problem hiding this comment.
SGTM, it seems reasonable to eventually factor this out as a naive_sum_product or sth. Before we refactor, we should fix the intractable tests below, as that will involve more lambdas etc.
I think we can safely make the following sequence of changes:
|
|
That sounds reasonable, provided the fixes necessary in |
Addresses pyro-ppl/pyro#2929
This implements a multi-sample
forward_filter_backward_rsample()for use in Pyro'sAutoGaussianguide.Changes
MonteCarlo. These changes preserve semantics of single-sampleMonteCarlo()but change semantics ofMonteCarlo(particles=Bint[n])from mean-reducing overparticlesto introducing a new batch dimension overparticles.Funsor.sample()to avoid scaling by numel(sampled_inputs). Correspondingly.unscaled_sample()is renamed to._sample()..reduce(ops.mean, ...). This breaks from.reduce(op, ...)supporting only associative ops, but this does seem like the cleanest syntax to supportReductionOps over discrete input variables, which will be an important pattern now that the1/numelscaling os no longer performed by.sample().funsor.recipesmodule with high-level algorithms intended for use in both Pyro and NumPyro. The idea so to maximize test sharing of these recipes by testing all backends in the funsor repo.batch_varsarg toAdjointTapeto support batched backward sample (this might be simplified by Constant Funsor #548)forward_sample()function fromadjoint().Tested