Skip to content

Add recipes.forward_filter_backward_rsample()#549

Merged
eb8680 merged 21 commits into
masterfrom
auto-gaussian
Sep 22, 2021
Merged

Add recipes.forward_filter_backward_rsample()#549
eb8680 merged 21 commits into
masterfrom
auto-gaussian

Conversation

@fritzo
Copy link
Copy Markdown
Member

@fritzo fritzo commented Sep 17, 2021

Addresses pyro-ppl/pyro#2929

This implements a multi-sample forward_filter_backward_rsample() for use in Pyro's AutoGaussian guide.

Changes

  • Modifying the semantics of MonteCarlo. These changes preserve semantics of single-sample MonteCarlo() but change semantics of MonteCarlo(particles=Bint[n]) from mean-reducing over particles to introducing a new batch dimension over particles.
  • Changes semantics of Funsor.sample() to avoid scaling by numel(sampled_inputs). Correspondingly .unscaled_sample() is renamed to ._sample().
  • Supporting .reduce(ops.mean, ...). This breaks from .reduce(op, ...) supporting only associative ops, but this does seem like the cleanest syntax to support ReductionOps over discrete input variables, which will be an important pattern now that the 1/numel scaling os no longer performed by .sample().
  • Adding a new funsor.recipes module with high-level algorithms intended for use in both Pyro and NumPyro. The idea so to maximize test sharing of these recipes by testing all backends in the funsor repo.
  • Adding a batch_vars arg to AdjointTape to support batched backward sample (this might be simplified by Constant Funsor #548)
  • Factoring out a new forward_sample() function from adjoint().

Tested

@fritzo fritzo added the WIP label Sep 17, 2021
@fritzo fritzo mentioned this pull request Sep 21, 2021
@fritzo fritzo changed the title Change semantics of multi-sample MonteCarlo to support AutoGaussian Add recipes.forward_filter_backward_rsample() Sep 22, 2021
@fritzo fritzo requested a review from eb8680 September 22, 2021 16:00
@fritzo fritzo marked this pull request as ready for review September 22, 2021 16:00
@fritzo fritzo added the enhancement New feature or request label Sep 22, 2021
@eb8680
Copy link
Copy Markdown
Member

eb8680 commented Sep 22, 2021

All of the relevant downstream tests in Numpyro pass under these changes (on my local machine, at least), but the changes to the behavior of .sample() seem to have broken a bunch of downstream tests in pyro.contrib.funsor.

Since we don't pin Pyro dev to Funsor master, one way to proceed would be for me to merge this and then put up a Pyro pull request with fixes that we can merge on the next Funsor or Pyro release. Does that sound reasonable?

Copy link
Copy Markdown
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good per Zoom reviews, just a few nits.

Comment thread funsor/terms.py Outdated
Comment thread funsor/terms.py
assert isinstance(reduced_vars, frozenset), reduced_vars

# Attempt to convert ReductionOp to AssociativeOp.
if isinstance(op, ops.ReductionOp):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this standardization logic to run under all interpretations?

Copy link
Copy Markdown
Member Author

@fritzo fritzo Sep 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I see .reduce() as lightweight syntax around deeper reinterpretable syntax including Reduce. Note it is nonsensical to create a lazy Reduce of a ReductionOp, the way we have defined Reduce.

Comment thread test/test_approximations.py Outdated
actual = q1 + s * q2
assert_close(actual, expected)

if approximate not in (monte_carlo, monte_carlo_10):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this will report the test as passing under monte_carlo, would it be better to keep the xfail status?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just removed these test as I don't understand what they do.

Comment thread test/test_approximations.py Outdated
actual = q1 + s * q2
assert_close(actual, expected)

if approximate not in (monte_carlo, monte_carlo_10):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: should this xfail rather than pass?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh right, I was going to strengthen these tests once we fixed adjoint, will do...

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed.

Comment thread test/test_recipes.py
def check_ffbr(factors, eliminate, plates, actual_samples, actual_log_prob):
"""
This can be seen as performing naive tensor variable elimination by
breaking all plates and creating a single flat joint distribution.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this actually turns out to be true, at some point we should move this logic into funsor.sum_product and test it against existing implementations since it's somewhat complex for testing logic and it is used to check correctness of what will presumably end up as one of our most important code paths if/when AutoGaussian is finished.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM, it seems reasonable to eventually factor this out as a naive_sum_product or sth. Before we refactor, we should fix the intractable tests below, as that will involve more lambdas etc.

@fritzo
Copy link
Copy Markdown
Member Author

fritzo commented Sep 22, 2021

one way to proceed would be for me to merge this and then put up a Pyro pull request with fixes

I think we can safely make the following sequence of changes:

  1. merge this PR
  2. fix Pyro dev and pin to a particular Funsor commit
  3. ...add AutoGaussian features via many commits to Funsor and Pyro...
  4. release Funsor
  5. pin Pyro to the Funsor release (required for Pyro releases)
  6. release Pyro

@eb8680
Copy link
Copy Markdown
Member

eb8680 commented Sep 22, 2021

That sounds reasonable, provided the fixes necessary in pyro.contrib.funsor aren't too onerous - I don't want to block AutoGaussian on that if it can be avoided

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

awaiting response enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants