Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 58 additions & 29 deletions csrc/preseg_passes/move_split_cat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <expr_simplifier.h>
#include <fusion.h>
#include <id_model/id_model.h>
#include <ir/builder.h>
#include <ir/interface_nodes.h>
#include <ir/internal_base_nodes.h>
Expand All @@ -21,9 +22,52 @@ namespace nvfuser::preseg_passes {

namespace {

// Returns true when Exprs in the frontier can be horizontally merged and
// applied on the unsplit tensor.
bool horizontallyMergeable(
class CancelSplitCat {
public:
CancelSplitCat(Fusion* fusion)
: fusion_(fusion),
id_model_(fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to consider not building all the graphs by default. If all we need is just the exact graph, we can skip generating the other graphs, which may be much more costly than just building the exact graph.

Copy link
Collaborator Author

@wujingyue wujingyue Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to leave that part in the follow-up PR.


// Finds all cancellable <split,cat> pairs, cancels them and horizontally
// merges ops in between.
void run();

private:
// Returns true when Exprs in the frontier can be horizontally merged and
// applied on the unsplit tensor.
bool horizontallyMergeable(
const std::vector<Expr*>& frontier,
int64_t& split_axis);

// Finds the canceling split of `cat` and returns the input TensorView of the
// split. A split (implemented as multiple `slice`s) and a cat cancel when
// they work on the same dimension. For example, when
//
// s0 = in[:, :5]
// s1 = in[:, 5:]
// out = cat([s0, s1], dim=-1)
//
// findCancelingSplit(out) returns `in`.
//
// `cat` doesn't have to immediately follow the split. For example, when
//
// s0 = in[:, :5]
// s1 = in[:, 5:]
// t0 = permute(s0)
// t1 = permute(s1)
// out = cat([t0, t1], dim=0)
//
// In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining
// `permute` into `use_def_chain` so the caller can reconstruct `out` by
// replaying `use_def_chain` (in reverse order) on `in`.
TensorView* findCancelingSplit(CatOp* cat, std::vector<Expr*>& use_def_chain);

Fusion* fusion_;

IdModel id_model_;
};

bool CancelSplitCat::horizontallyMergeable(
const std::vector<Expr*>& frontier,
int64_t& split_axis) {
NVF_ERROR(!frontier.empty());
Expand Down Expand Up @@ -105,28 +149,9 @@ std::pair<std::vector<PadOp*>, int64_t> getCatInputsAndAxis(CatOp* cat) {
return {pads, cat_axis};
}

// Finds the canceling split of `cat` and returns the input TensorView of the
// split. A split (implemented as multiple `slice`s) and a cat cancel when they
// work on the same dimension. For example, when
//
// s0 = in[:, :5]
// s1 = in[:, 5:]
// out = cat([s0, s1], dim=-1)
//
// findCancelingSplit(out) returns `in`.
//
// `cat` doesn't have to immediately follow the split. For example, when
//
// s0 = in[:, :5]
// s1 = in[:, 5:]
// t0 = permute(s0)
// t1 = permute(s1)
// out = cat([t0, t1], dim=0)
//
// In addition to returning `in`, findCancelingSplit(out) puts `t0`'s defining
// `permute` into `use_def_chain` so the caller can reconstruct `out` by
// replaying `use_def_chain` (in reverse order) on `in`.
TensorView* findCancelingSplit(CatOp* cat, std::vector<Expr*>& use_def_chain) {
TensorView* CancelSplitCat::findCancelingSplit(
CatOp* cat,
std::vector<Expr*>& use_def_chain) {
NVF_CHECK(!cat->inputs().empty(), "`cat` has zero inputs: ", cat);

auto [pads, cat_axis] = getCatInputsAndAxis(cat);
Expand Down Expand Up @@ -228,10 +253,8 @@ TensorView* findCancelingSplit(CatOp* cat, std::vector<Expr*>& use_def_chain) {
return split_in;
}

} // namespace

void MoveSplitCatPass::runPass(Fusion* fusion) {
std::vector<Expr*> exprs = fusion->exprs();
void CancelSplitCat::run() {
std::vector<Expr*> exprs = fusion_->exprs();
for (auto* cat : ir_utils::filterByType<CatOp>(exprs)) {
std::vector<Expr*> use_def_chain;
TensorView* split_in = findCancelingSplit(cat, std::ref(use_def_chain));
Expand Down Expand Up @@ -261,4 +284,10 @@ void MoveSplitCatPass::runPass(Fusion* fusion) {
}
}

} // namespace

// Entry point of the pass: builds a CancelSplitCat helper bound to `fusion`
// and invokes its run() method. The helper lives only for the duration of
// this call.
void MoveSplitCatPass::runPass(Fusion* fusion) {
  CancelSplitCat canceller(fusion);
  canceller.run();
}

} // namespace nvfuser::preseg_passes