diff --git a/csrc/preseg_passes/move_split_cat.cpp b/csrc/preseg_passes/move_split_cat.cpp index 1764e7be52f..831a7739f40 100644 --- a/csrc/preseg_passes/move_split_cat.cpp +++ b/csrc/preseg_passes/move_split_cat.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -21,9 +22,52 @@ namespace nvfuser::preseg_passes { namespace { -// Returns true when Exprs in the frontier can be horizontally merged and -// applied on the unsplit tensor. -bool horizontallyMergeable( +class CancelSplitCat { + public: + CancelSplitCat(Fusion* fusion) + : fusion_(fusion), + id_model_(fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true) {} + + // Finds all cancellable pairs, cancels them and horizontallly + // merges ops in between. + void run(); + + private: + // Returns true when Exprs in the frontier can be horizontally merged and + // applied on the unsplit tensor. + bool horizontallyMergeable( + const std::vector& frontier, + int64_t& split_axis); + + // Finds the canceling split of `cat` and returns the input TensorView of the + // split. A split (implemented as multiple `slice`s) and a cat cancel when + // they work on the same dimension. For example, when + // + // s0 = in[:, :5] + // s1 = in[:, 5:] + // out = cat([s0, s1], dim=-1) + // + // findCancelingSplit(out) returns `in`. + // + // `cat` doesn't have to immediately follow the split. For example, when + // + // s0 = in[:, :5] + // s1 = in[:, 5:] + // t0 = permute(s0) + // t1 = permute(s1) + // out = cat([t0, t1], dim=0) + // + // In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining + // `permute` into `use_def_chain` so the caller can reconstruct `out` by + // replaying `use_def_chain` (in reverse order) on `in`. + TensorView* findCancelingSplit(CatOp* cat, std::vector& use_def_chain); + + Fusion* fusion_; + + IdModel id_model_; +}; + +bool CancelSplitCat::horizontallyMergeable( const std::vector& frontier, int64_t& split_axis) { NVF_ERROR(!frontier.empty()); @@ -105,28 +149,9 @@ std::pair, int64_t> getCatInputsAndAxis(CatOp* cat) { return {pads, cat_axis}; } -// Finds the canceling split of `cat` and returns the input TensorView of the -// split. A split (implemented as multiple `slice`s) and a cat cancel when they -// work on the same dimension. For example, when -// -// s0 = in[:, :5] -// s1 = in[:, 5:] -// out = cat([s0, s1], dim=-1) -// -// findCancelingSplit(out) returns `in`. -// -// `cat` doesn't have to immediately follow the split. For example, when -// -// s0 = in[:, :5] -// s1 = in[:, 5:] -// t0 = permute(s0) -// t1 = permute(s1) -// out = cat([t0, t1], dim=0) -// -// In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining -// `permute` into `use_def_chain` so the caller can reconstruct `out` by -// replaying `use_def_chain` (in reverse order) on `in`. -TensorView* findCancelingSplit(CatOp* cat, std::vector& use_def_chain) { +TensorView* CancelSplitCat::findCancelingSplit( + CatOp* cat, + std::vector& use_def_chain) { NVF_CHECK(!cat->inputs().empty(), "`cat` has zero inputs: ", cat); auto [pads, cat_axis] = getCatInputsAndAxis(cat); @@ -228,10 +253,8 @@ TensorView* findCancelingSplit(CatOp* cat, std::vector& use_def_chain) { return split_in; } -} // namespace - -void MoveSplitCatPass::runPass(Fusion* fusion) { - std::vector exprs = fusion->exprs(); +void CancelSplitCat::run() { + std::vector exprs = fusion_->exprs(); for (auto* cat : ir_utils::filterByType(exprs)) { std::vector use_def_chain; TensorView* split_in = findCancelingSplit(cat, std::ref(use_def_chain)); @@ -261,4 +284,10 @@ void MoveSplitCatPass::runPass(Fusion* fusion) { } } +} // namespace + +void MoveSplitCatPass::runPass(Fusion* fusion) { + CancelSplitCat(fusion).run(); +} + } // namespace nvfuser::preseg_passes