Description
Proposal
Add a concretization-like step that mutates the TensorView definition to carry the correct stride_order (allocation domain) and contiguity flags.
Context
Currently, the nvfuser API requires the user to statically define broadcast dimensions, as well as stride order (the latter is not even in our integration yet!), at the time a fusion is defined. Any later change to the tensor shape or strides that contradicts these definitions requires the user/integration to handle it with a custom cache strategy.
It is challenging to get a custom cache right because of the complexity involved. The caching strategy in nvfuser also evolves over time, which makes such a cache even harder to maintain.
We should try to keep this complexity inside nvfuser and abstract such definitions away from the user API.
That is, when we define a TensorView, we specify only its rank, without specifying the stride order or broadcast semantics. At runtime, when a concrete input tensor is given, we concretize the definition of each tensor and propagate that information to obtain a concrete fusion that can be scheduled.
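As a rough illustration of what concretizing a rank-only input could compute, here is a minimal Python sketch that derives a stride order, per-dimension contiguity, and broadcast flags from a concrete tensor's sizes and strides. The names `concretize_tensor` and `ConcreteTensorInfo`, the outermost-to-innermost stride-order convention, and treating size-1 or stride-0 dimensions as broadcasts are all assumptions made for illustration, not nvfuser's actual API or rules.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class ConcreteTensorInfo:
    # Dims listed from outermost (largest stride) to innermost (smallest stride).
    stride_order: Tuple[int, ...]
    # Per-dim contiguity; None marks broadcast dims (hypothetical convention).
    contiguity: Tuple[Optional[bool], ...]
    # True where the dim is treated as a broadcast (size 1 or stride 0).
    is_broadcast: Tuple[bool, ...]


def concretize_tensor(sizes: Sequence[int], strides: Sequence[int]) -> ConcreteTensorInfo:
    """Hypothetical helper: derive layout info for one concrete input."""
    if len(sizes) != len(strides):
        raise ValueError("sizes and strides must have the same rank")
    rank = len(sizes)
    is_broadcast = tuple(sz == 1 or st == 0 for sz, st in zip(sizes, strides))
    # Sort dims by stride, largest first; ties are broken by dim index.
    order = tuple(sorted(range(rank), key=lambda d: (-strides[d], d)))
    contiguity: List[Optional[bool]] = [None] * rank
    expected = 1  # stride a dim must have to be contiguous with its inner neighbor
    for d in reversed(order):  # walk from innermost to outermost
        if is_broadcast[d]:
            continue  # broadcast dims are skipped in the contiguity chain
        contiguity[d] = strides[d] == expected
        expected = strides[d] * sizes[d]
    return ConcreteTensorInfo(order, tuple(contiguity), is_broadcast)


# A channels-last (NHWC in memory) tensor with logical NCHW sizes.
info = concretize_tensor(sizes=(2, 3, 4, 5), strides=(60, 1, 15, 3))
print(info.stride_order)  # (0, 2, 3, 1)
print(info.contiguity)    # (True, True, True, True)
```

Breaking stride ties by dimension index keeps the result deterministic, which matters if this information is later used to look up a cached concrete fusion.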
Eventually we would want to use this to handle both stride_order and shape specialization (broadcast, size-0, or whatever new case pops up).
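One way to keep the caching inside the library is to key concrete fusions on the concretization information itself. The sketch below is a hypothetical Python model of that idea, not nvfuser's implementation: `SymbolicFusion`, `input_key`, and the `build` hook are invented names, and the key (stride order plus size-1 broadcast flags per input) is an assumed choice. The user defines the fusion once by rank, and each distinct input layout triggers one concretization that is then reused.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence, Tuple

# Hypothetical per-input concretization key: (stride order, broadcast flags).
InputKey = Tuple[Tuple[int, ...], Tuple[bool, ...]]


def input_key(sizes: Sequence[int], strides: Sequence[int]) -> InputKey:
    # Dims from outermost (largest stride) to innermost; ties broken by dim index.
    order = tuple(sorted(range(len(strides)), key=lambda d: (-strides[d], d)))
    broadcast = tuple(s == 1 for s in sizes)
    return order, broadcast


@dataclass
class SymbolicFusion:
    """Hypothetical fusion whose inputs are declared by rank only."""
    input_ranks: Tuple[int, ...]
    # Hypothetical hook that builds (and would schedule) a concrete fusion for one key.
    build: Callable[[Tuple[InputKey, ...]], object]
    cache: Dict[Tuple[InputKey, ...], object] = field(default_factory=dict)
    builds: int = 0

    def concretize(self, inputs: Sequence[Tuple[Sequence[int], Sequence[int]]]) -> object:
        if len(inputs) != len(self.input_ranks):
            raise ValueError("wrong number of inputs")
        for rank, (sizes, strides) in zip(self.input_ranks, inputs):
            if len(sizes) != rank or len(strides) != rank:
                raise ValueError("input rank does not match the symbolic definition")
        key = tuple(input_key(sizes, strides) for sizes, strides in inputs)
        if key not in self.cache:  # miss: concretize once, then reuse
            self.builds += 1
            self.cache[key] = self.build(key)
        return self.cache[key]


fusion = SymbolicFusion(input_ranks=(4,), build=lambda key: ("concrete", key))
nchw = fusion.concretize([((2, 3, 4, 5), (60, 20, 5, 1))])  # builds
nhwc = fusion.concretize([((2, 3, 4, 5), (60, 1, 15, 3))])  # builds
again = fusion.concretize([((2, 3, 4, 5), (60, 20, 5, 1))])  # cache hit
print(fusion.builds)  # 2
```

Because the key is computed by the library, the user never writes layout-specific cache logic, and the key can grow (for example with size-0 flags) without changing the user-facing definition.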
TODO
- We can take the initial step of supporting stride_order, since it's simpler to add and doesn't require propagation. It works much like the reshape handling in concretization and is somewhat isolated from the rest of the cache system in our stack.
- The protocol for interpreting a stride order can be tricky, and propagating it (if we want to) is even harder. (For some fun context, see pytorch/pytorch#23403, "Memory Format support for Resnet models", and pytorch/pytorch#24090, "suggest_memory_format has ambiguity & cannot represent intended layout format for corner cases".) Luckily, once this logic lives inside codegen, changing it would be relatively easy, so we don't have to make the right decision up front.
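To illustrate why interpreting a stride order is tricky (our own example, not taken from the linked issues): when several dimensions share a stride, as size-1 dimensions often do, more than one order is consistent with the same strides. The hypothetical helper `consistent_stride_orders` below enumerates every outermost-to-innermost dimension order whose strides are non-increasing; the ordering convention is an assumption for illustration.

```python
from itertools import permutations
from typing import List, Sequence, Tuple


def consistent_stride_orders(strides: Sequence[int]) -> List[Tuple[int, ...]]:
    """Hypothetical helper: every dim order (outermost first) with non-increasing strides."""
    rank = len(strides)
    return [
        perm
        for perm in permutations(range(rank))
        if all(strides[perm[i]] >= strides[perm[i + 1]] for i in range(rank - 1))
    ]


# Contiguous strides for NCHW sizes (2, 3, 1, 1): H and W have size 1.
orders = consistent_stride_orders((3, 1, 1, 1))
print(len(orders))             # 6
print((0, 1, 2, 3) in orders)  # True: reads as NCHW
print((0, 2, 3, 1) in orders)  # True: reads as NHWC (channels-last)
```

Any concretization step has to pick one of these deterministically (for instance by breaking ties on dimension index), which is one of the protocol decisions left open here.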