
[WIP DO NOT REVIEW] "Fold" operations to generalize Reductions#2307

Closed
jacobhinkle wants to merge 22 commits into main from fold_ops

Conversation


@jacobhinkle jacobhinkle commented May 28, 2024

This PR is an experiment in generalizing ReductionOp to handle a wider class of operations than those in BinaryOpType. The approach can be summarized as follows:

  1. Introduce a new IterType::Fold which represents dimensions that are "being reduced". These IterDomains must always be inlined with one another. The term "fold" was chosen to not conflict with the existing "reduction" terminology.
  2. Introduce new IR nodes representing begin and end of a fold operation. A "fold group" is defined as all the ops between these two nodes.
  3. When fold groups are finalized, the output tensors have IterType::Reduction dimensions.
  4. During lowering, translate these nodes into assignments using kir::Assign nodes, which allow us to reassign a variable. This lets us update accumulator tensors inside a loop, for example.
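The steps above can be illustrated with a minimal sketch. This is not nvFuser code or its API; it is a hypothetical Python analogue showing how a fold group, delimited by a begin and an end, lowers to loop-carried reassignment of an accumulator, in the spirit of the kir::Assign semantics described in step 4:

```python
def fold_sum(values):
    # "fold begin": initialize the accumulator for the fold group.
    acc = 0.0
    # The fold dimension corresponds to this loop; each iteration
    # reassigns acc (analogous to a kir::Assign node in lowering).
    for v in values:
        acc = acc + v
    # "fold end": finalize the group; acc holds the reduced output,
    # analogous to the output gaining an IterType::Reduction dimension.
    return acc

print(fold_sum([1.0, 2.0, 3.0]))  # 6.0
```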

The goals of this design are:

  1. Enable non-trivial reductions like online softmax.
  2. Enable nested reductions for cases that cannot easily be written as rfactors. The outer loop of FlashAttention1 is a good example of such a case.
  3. Avoid reinventing, as much as possible, the complicated machinery in lowering such as inlining semantics and indexing.
  4. If there is a non-awkward way to implement scan in this setting, do so, but only as a secondary consideration. The current implementation seeks to do this by allowing both scan and reduction outputs when finalizing a fold group.
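As a concrete instance of goal 1, the standard one-pass online softmax recurrence is the kind of non-trivial reduction a fold group would express: it carries two coupled accumulators (a running max and a running denominator) that must be updated together inside the loop. A plain-Python sketch for reference (again, not nvFuser code):

```python
import math

def online_softmax(xs):
    # Two loop-carried accumulators, updated jointly each iteration:
    # m is the running max, d the running (rescaled) denominator.
    m = float("-inf")
    d = 0.0
    for x in xs:
        m_new = max(m, x)
        # Rescale the partial denominator whenever the running max grows.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    # Finalization: apply the completed accumulators elementwise.
    return [math.exp(x - m) / d for x in xs]
```

The joint update of `m` and `d` is exactly what a single BinaryOpType-based ReductionOp cannot express, which motivates treating the whole loop body as one fold group.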
