Description
With dynamic resize, we implemented a very basic system that tries to prove propositions at definition time (before concretization) when some inputs are constant. For example, if the pad widths are constants that sum to 2, we know the padded extent will be at least 2 even if we don't know the input extent, since that extent is non-negative. This lets us avoid symbolic operations in some common cases. Other ops, such as cat, pad using more complicated expressions, so the first point of this issue is that more sophisticated theorem proving at definition time could be advantageous.
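A minimal sketch of the definition-time check described above, written in Python for brevity. It is not nvFuser code: the helper names are hypothetical, and the only assumption it makes about a symbolic input extent is that it is non-negative, so its lower bound is 0.

```python
from typing import Optional


def padded_extent_lower_bound(
    input_extent: Optional[int], pad_left: int, pad_right: int
) -> int:
    """Lower bound on input_extent + pad_left + pad_right.

    input_extent is None when it is symbolic (unknown before concretization);
    in that case we only assume it is non-negative, so its lower bound is 0.
    """
    base = input_extent if input_extent is not None else 0
    return base + pad_left + pad_right


def can_prove_padded_extent_at_least(
    input_extent: Optional[int], pad_left: int, pad_right: int, bound: int
) -> bool:
    """True if the padded extent is provably >= bound.

    False means "could not prove", not "disproved": with a symbolic input the
    actual extent may still exceed the bound at runtime.
    """
    return padded_extent_lower_bound(input_extent, pad_left, pad_right) >= bound


# Constant pads summing to 2 on an unknown input: provable at definition time.
print(can_prove_padded_extent_at_least(None, 1, 1, 2))
# Negative pads: the answer depends on the unknown input extent.
print(can_prove_padded_extent_at_least(None, -1, -1, 1))
```

Returning False for "unknown" keeps the check sound: a caller falls back to symbolic operations whenever the bound cannot be proven.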
NVFuser already has a powerful system for simplifying expressions: simplifyExpr. It performs iterative term rewriting under a provided set of assumptions. By default, these assumptions are the "axioms" returned by IrContainer::axioms(), a collection of Bools known to be true, which state that parallelization dimensions are positive and that thread and block IDs are non-negative.
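To make the idea concrete, here is a toy Python analogue of iterative term rewriting under a set of axioms. It is not simplifyExpr's implementation or API: the tuple encoding of expressions, the small rule set, and the names `simplify` and `is_nonneg` are illustrative assumptions, and the two axioms only mimic the kind of facts IrContainer::axioms() provides.

```python
from typing import FrozenSet, Tuple, Union

# Expressions: ints, bools, symbol names (str), or binary nodes (op, lhs, rhs).
Expr = Union[int, bool, str, Tuple]
Fact = Tuple[str, str, int]


def is_int(e: Expr) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(e, int) and not isinstance(e, bool)


def is_nonneg(e: Expr, axioms: FrozenSet[Fact]) -> bool:
    """Can we show e >= 0 from the axioms?"""
    if is_int(e):
        return e >= 0
    if isinstance(e, str):
        return ("ge", e, 0) in axioms or ("gt", e, 0) in axioms
    if isinstance(e, tuple) and e[0] in ("add", "mul"):
        return is_nonneg(e[1], axioms) and is_nonneg(e[2], axioms)
    return False


def rewrite_once(e: Expr, axioms: FrozenSet[Fact]) -> Expr:
    """One bottom-up rewriting pass: children first, then the root."""
    if not isinstance(e, tuple):
        return e
    op, lhs, rhs = e
    lhs, rhs = rewrite_once(lhs, axioms), rewrite_once(rhs, axioms)
    if op == "add":
        if is_int(lhs) and is_int(rhs):
            return lhs + rhs  # constant folding
        if is_int(rhs) and rhs == 0:
            return lhs  # x + 0 -> x
        if is_int(lhs) and lhs == 0:
            return rhs  # 0 + x -> x
    if op == "max" and is_int(rhs) and rhs == 0 and is_nonneg(lhs, axioms):
        return lhs  # max(x, 0) -> x when x >= 0 follows from the axioms
    if op == "ge" and is_int(rhs) and rhs == 0 and is_nonneg(lhs, axioms):
        return True  # x >= 0 is provable from the axioms
    return (op, lhs, rhs)


def simplify(e: Expr, axioms: FrozenSet[Fact], max_iters: int = 100) -> Expr:
    """Repeat rewriting passes until a fixed point (or an iteration cap)."""
    for _ in range(max_iters):
        new = rewrite_once(e, axioms)
        if new == e:
            break
        e = new
    return e


AXIOMS = frozenset({("gt", "blockDim.x", 0), ("ge", "threadIdx.x", 0)})
print(simplify(("ge", ("add", "threadIdx.x", "blockDim.x"), 0), AXIOMS))
print(simplify(("ge", "i0", 0), AXIOMS))  # no axiom for i0, so it stays unsimplified
```

The last line shows the gap the next paragraph points at: without an axiom about a fusion input extent such as the hypothetical `i0`, even `i0 >= 0` cannot be simplified.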
More axioms could empower this system. For example, we know that all fusion input extents are non-negative. During concretization, we analyze individual ops, extract the values of their scalars, and make concretization decisions. Each decision point is then a predicate that can be treated as constant for that concretization. For example, if we determine that the extent of a pad op's output equals 1, that will be true for every set of inputs sharing that concretization. The second point of this issue runs in the other direction: axioms accumulated during definition and concretization might be useful during lowering.
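The following hypothetical Python sketch models a concretization decision as a fact. For one pad op, the decision "is the output extent 1?" is recorded as an axiom, together with the non-negativity of the input extent. The symbol names `i0` and `pad_out` and the tuple encoding of facts are assumptions made for illustration, not nvFuser's data structures.

```python
from typing import FrozenSet, Tuple

Fact = Tuple[str, str, int]


def pad_concretization_axioms(i0: int, pad_left: int, pad_right: int) -> FrozenSet[Fact]:
    """Facts that hold for every input sharing this concretization of one pad op.

    Hypothetical model: the pad input extent is the symbol "i0", the output
    extent is the symbol "pad_out", and the only concretization decision is
    whether the output extent equals 1.
    """
    facts = {("ge", "i0", 0)}  # fusion input extents are non-negative
    out_extent = i0 + pad_left + pad_right
    # Record the decision itself, not the concrete value of i0.
    facts.add(("eq", "pad_out", 1) if out_extent == 1 else ("ne", "pad_out", 1))
    return frozenset(facts)


def output_extent_is_one(axioms: FrozenSet[Fact]) -> bool:
    """A lowering-time query answered from the attested facts alone."""
    return ("eq", "pad_out", 1) in axioms


a = pad_concretization_axioms(3, -1, -1)  # output extent 1
b = pad_concretization_axioms(5, -2, -2)  # output extent 1, different input
c = pad_concretization_axioms(4, -1, -1)  # output extent 2
print(a == b, output_extent_is_one(a), output_extent_is_one(c))  # True True False
```

Because two different inputs produce the same fact set, anything lowering derives from those facts holds for every input that shares this concretization.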
Approaches
I think one flexible way to enable this is to let IrContainer accept additional axioms through a new attestAxiom method. Creating a TensorView would then add an axiom stating that each of its extents is non-negative. We could then use simplifyExpr directly during definition or lowering, and sprinkle in more axioms to aid the system. Alternatively, we could simply override axioms() for Fusion so that it also returns non-negativity axioms for the input extents, but this is less flexible.
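A minimal Python sketch of the attestAxiom idea, assuming each container keeps a mutable list of facts. `IrContainerSketch`, `TensorViewSketch`, and `attest_axiom` are hypothetical stand-ins for IrContainer, TensorView, and the proposed attestAxiom method; they do not reflect nvFuser's actual classes or signatures.

```python
from typing import List, Tuple

Fact = Tuple[str, str, int]


class IrContainerSketch:
    """Hypothetical container whose axiom list can be extended after creation."""

    def __init__(self) -> None:
        # Stand-ins for the built-in axioms about parallel dimensions and indices.
        self._axioms: List[Fact] = [("gt", "blockDim.x", 0), ("ge", "threadIdx.x", 0)]

    def axioms(self) -> Tuple[Fact, ...]:
        return tuple(self._axioms)

    def attest_axiom(self, fact: Fact) -> None:
        """Append a fact known to be true; duplicates are ignored."""
        if fact not in self._axioms:
            self._axioms.append(fact)


class TensorViewSketch:
    """Hypothetical tensor: creating one attests that each extent is non-negative."""

    def __init__(self, container: IrContainerSketch, extents: List[str]) -> None:
        self.extents = list(extents)
        for extent in self.extents:
            container.attest_axiom(("ge", extent, 0))


fusion = IrContainerSketch()
TensorViewSketch(fusion, ["i0", "i1"])
TensorViewSketch(fusion, ["i1", "i2"])  # i1 is already attested, so it is not duplicated
print(fusion.axioms())
```

Compared with overriding axioms() for Fusion, attesting at creation time lets any stage (definition, concretization, or lowering) add facts as it discovers them, rather than fixing the set in one place.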