This aims to work around backend jit issues and improve speed of Funsor used in Pyro. The approach is to first perform symbolic computations among Funsors, then lower to simple Funsor expressions, then compile the lowered Funsor expression to a straight-line program involving only backend ops, then optionally convert to Python code. The final program depends only on funsor.ops. This eliminates interpreter overhead, but does not eliminate op dispatch overhead.
The immediate application is to speed up Pyro's AutoGaussian guide pyro-ppl/pyro#2929, but this will also require symbolic Gaussians #556.
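To make the lower-then-compile idea above concrete, here is a minimal sketch that flattens a nested expression into a straight-line program and then renders it as Python source. Everything in it is an assumption for illustration: the tuple-based expression format, the `OPS` dict (a stand-in for a backend op registry, not the real funsor.ops), and the helper names `compile_expr`, `run` and `to_python` are hypothetical and are not this PR's API.

```python
import operator

# Hypothetical stand-in for a backend op registry (not the real funsor.ops).
OPS = {"add": operator.add, "mul": operator.mul, "neg": operator.neg}


def compile_expr(expr):
    """Flatten a nested expression into a straight-line program.

    An expression is either a string (an input name) or a tuple
    (op_name, *argument_expressions). Returns (program, output_register),
    where each instruction is (target_register, op_name, arg_registers).
    Assumes input names do not collide with the generated "t<k>" registers.
    """
    program = []
    cache = {}  # shared subexpressions are emitted only once

    def visit(node):
        if isinstance(node, str):
            return node  # inputs are read directly from the environment
        if node not in cache:
            op_name, *args = node
            arg_regs = tuple(visit(arg) for arg in args)
            target = f"t{len(program)}"
            program.append((target, op_name, arg_regs))
            cache[node] = target
        return cache[node]

    return program, visit(expr)


def run(program, output, **inputs):
    """Execute the program with a plain loop; no tree walking remains."""
    env = dict(inputs)
    for target, op_name, arg_regs in program:
        env[target] = OPS[op_name](*(env[r] for r in arg_regs))
    return env[output]


def to_python(program, output, input_names):
    """Render the program as Python source defining a function `compiled`."""
    lines = [f"def compiled({', '.join(input_names)}):"]
    for target, op_name, arg_regs in program:
        lines.append(f"    {target} = ops['{op_name}']({', '.join(arg_regs)})")
    lines.append(f"    return {output}")
    return "\n".join(lines)


# (x + y) * (x + y) + (-y), with the shared subexpression x + y
xy = ("add", "x", "y")
expr = ("add", ("mul", xy, xy), ("neg", "y"))
program, out = compile_expr(expr)

# Optional last stage: turn the program into ordinary Python code.
source = to_python(program, out, ["x", "y"])
namespace = {"ops": OPS}
exec(source, namespace)
compiled = namespace["compiled"]
```

In this sketch each instruction still looks up and calls an op, which mirrors the point above: the interpreter's tree walk goes away, while the per-op dispatch cost stays.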
Tasks
compile lowered funsors to Funsor-free programs
lower Contraction to Binary
trace simple op graphs
support tuples in compile_function() for functions with multiple outputs (e.g. forward filter backward precondition).
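As a rough illustration of the "lower Contraction to Binary" task, the sketch below rewrites an n-ary contraction into a left-nested chain of binary terms wrapped in one reduction. The namedtuple term types are simplified, hypothetical stand-ins (the real Funsor terms carry typed inputs and outputs), and `lower` here is not the function added in this PR. It assumes every contraction has at least one term.

```python
from collections import namedtuple
from functools import reduce

# Hypothetical, simplified term types used only for this sketch.
Variable = namedtuple("Variable", ["name"])
Binary = namedtuple("Binary", ["op", "lhs", "rhs"])
Reduce = namedtuple("Reduce", ["op", "arg", "reduced_vars"])
Contraction = namedtuple(
    "Contraction", ["red_op", "bin_op", "reduced_vars", "terms"]
)


def lower(term):
    """Rewrite every Contraction into nested Binary terms, wrapped in a
    single Reduce when there are variables to reduce over."""
    if isinstance(term, Contraction):
        terms = [lower(t) for t in term.terms]
        # Fold the n-ary product into a left-nested chain of binary ops.
        combined = reduce(
            lambda lhs, rhs: Binary(term.bin_op, lhs, rhs), terms
        )
        if term.reduced_vars:
            return Reduce(term.red_op, combined, term.reduced_vars)
        return combined
    if isinstance(term, Binary):
        return Binary(term.op, lower(term.lhs), lower(term.rhs))
    if isinstance(term, Reduce):
        return Reduce(term.op, lower(term.arg), term.reduced_vars)
    return term  # leaves such as Variable are already simple


x, y, z = Variable("x"), Variable("y"), Variable("z")
contraction = Contraction("logaddexp", "add", frozenset({"i"}), (x, y, z))
lowered = lower(contraction)
```

After lowering, only `Variable`, `Binary` and `Reduce` nodes remain, which is the kind of simple expression a compiler stage can turn into one backend op call per node.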
Tasks deferred to follow-up PRs
lower bound variables, maybe via Lambda or ops.einsum? E.g. to eliminate the bound variable i:
Looks great, and definitely feasible in some form. What about tracing individual funsor.Ops rather than Funsor terms/rewrite rules? The interface might be a bit more awkward, but we could have a very low-overhead untyped graph representation that way.
I guess that approach would require strict use of funsor.ops for every eager computation in every pattern, which is probably not realistic or desirable.
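For reference, op-level tracing into an untyped graph could look roughly like the sketch below. The names `Node`, `Trace`, `make_op`, `trace_ops` and `replay` are hypothetical, and the `OPS` dict is a stand-in rather than funsor.ops. The sketch only handles calls whose arguments are all traced values; constants mixed with traced values are not supported.

```python
import operator


class Node:
    """Untyped placeholder for a value that will exist at run time."""

    def __init__(self, index):
        self.index = index


class Trace:
    """Records op calls as (op_name, input_indices) rows.

    Row k defines the value with index num_inputs + k.
    """

    def __init__(self, num_inputs):
        self.num_inputs = num_inputs
        self.rows = []

    def record(self, op_name, args):
        self.rows.append((op_name, tuple(a.index for a in args)))
        return Node(self.num_inputs + len(self.rows) - 1)


ACTIVE = []  # stack of traces that are currently recording


def make_op(name, fn):
    """Wrap a backend function so it either computes eagerly or, when a
    trace is active and all arguments are Nodes, records a graph row."""

    def op(*args):
        if ACTIVE and all(isinstance(a, Node) for a in args):
            return ACTIVE[-1].record(name, args)
        return fn(*args)

    op.fn = fn  # keep the raw backend function for replay
    return op


OPS = {
    "add": make_op("add", operator.add),
    "mul": make_op("mul", operator.mul),
}


def trace_ops(fn, num_inputs):
    """Run fn once on placeholders and return (trace, output_index)."""
    trace = Trace(num_inputs)
    ACTIVE.append(trace)
    try:
        out = fn(*(Node(i) for i in range(num_inputs)))
    finally:
        ACTIVE.pop()
    return trace, out.index


def replay(trace, out_index, *inputs):
    """Re-execute the recorded rows on concrete inputs."""
    values = list(inputs)
    for op_name, arg_indices in trace.rows:
        values.append(OPS[op_name].fn(*(values[i] for i in arg_indices)))
    return values[out_index]


def model(x, y):
    # Every computation must go through the op wrappers to be recorded.
    return OPS["add"](OPS["mul"](x, y), x)


trace, out = trace_ops(model, 2)
```

The body of `model` shows the cost raised above: anything written with plain operators instead of the wrapped ops would bypass the trace.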
🤔 Interesting, I guess that would avoid both a lowering stage and the need to refactor Gaussian to be symbolic.
> strict use of funsor.ops for every eager computation
Yes, I guess we'd need to manually desugar all math e.g. x + y --> ops.add(x, y), even in Gaussian, which IMO would make math less maintainable. It might actually be easier to refactor Gaussian to be symbolic. I wonder if funsor.syntax could help, since it desugars to use funsor.ops 😬
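One way to avoid hand-writing the desugared form is to rewrite the syntax tree automatically. The sketch below uses Python's standard `ast` module to turn `x + y` into `ops.add(x, y)`. It is written independently of however funsor.syntax actually works; the `ops` namespace here is a hypothetical stand-in that also logs each call, and only `+` and `*` are handled.

```python
import ast
import operator
from types import SimpleNamespace

CALLS = []  # log of op names, to show that calls go through `ops`


def _logged(name, fn):
    def op(*args):
        CALLS.append(name)
        return fn(*args)

    return op


# Hypothetical stand-in for an ops namespace (not the real funsor.ops).
ops = SimpleNamespace(
    add=_logged("add", operator.add),
    mul=_logged("mul", operator.mul),
)

BINOP_NAMES = {ast.Add: "add", ast.Mult: "mul"}


class DesugarBinOps(ast.NodeTransformer):
    """Rewrite `a + b` into `ops.add(a, b)`, and likewise for `*`."""

    def visit_BinOp(self, node):
        self.generic_visit(node)  # rewrite nested operands first
        name = BINOP_NAMES.get(type(node.op))
        if name is None:
            return node  # leave unsupported operators untouched
        func = ast.Attribute(
            value=ast.Name(id="ops", ctx=ast.Load()),
            attr=name,
            ctx=ast.Load(),
        )
        call = ast.Call(func=func, args=[node.left, node.right], keywords=[])
        return ast.copy_location(call, node)


def desugar(source):
    """Return source text with supported operators replaced by op calls."""
    tree = DesugarBinOps().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # ast.unparse requires Python 3.9+


source = "def f(x, y):\n    return x * y + x\n"
desugared = desugar(source)

namespace = {"ops": ops}
exec(desugared, namespace)
f = namespace["f"]
```

With something like this, the math stays readable in the source while the executed code only calls ops, at the price of an extra source-rewriting step.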
We could also experiment with torch.fx for even lower level tracing, which would save us having to change a bunch of Funsor code. That would make this module specific to the PyTorch backend for now, but that doesn't seem so bad since it's currently our only use case.
Refactoring Gaussian seems worthwhile regardless of how we go about this, but would be lower priority.
@eb8680 WDYT about merging this partial implementation (w/o lower() support for bound variables), and simultaneously working on all the solutions you proposed: this funsor.compiler (this PR and subsequent pair coding sessions), an op-level funsor.ops.trace module, and torch.fx (would this be a backend?)?
Addresses pyro-ppl/pyro#2929
pair coded with @eb8680
Tasks deferred to follow-up PRs
lower bound variables, maybe via Lambda or ops.einsum? E.g. to eliminate the bound variable i:
… funsor.ops internally (not opt_einsum with direct backends), or make opt_einsum an op or sth
… funsor.ops internally
… trace_function() to support multiple outputs
Tested
compile_funsor()
trace_function()