Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
6061f3b
Initial draft of empty remover pass
jacobhinkle Jun 28, 2023
4bce330
Add NVFUSER_DUMP=fusion_ir_preseg to view opt pass output
jacobhinkle Jun 28, 2023
6c64915
Add FusionRemoveEmptyOutput_CUDA
jacobhinkle Jun 28, 2023
6fb368c
Fill with zeroVal instead of oneVal
jacobhinkle Jun 28, 2023
0d8b4fc
Add FusionRemoveEmptyReduction_CUDA
jacobhinkle Jun 28, 2023
436832f
Handle CatOp, with test
jacobhinkle Jun 28, 2023
50d763f
Move tests to test_optimization_pass.cpp
jacobhinkle Jun 28, 2023
56dcfae
Handle PadOp
jacobhinkle Jun 28, 2023
c776878
Use deque instead of stack.
jacobhinkle Jun 28, 2023
fb1b670
Use BackwardVisitor
jacobhinkle Jun 28, 2023
f6487e0
Handle WelfordOp
jacobhinkle Jun 28, 2023
b9953f7
Silence clang-tidy and convert empty check to TORCH_WARN
jacobhinkle Jun 28, 2023
32ab55a
Use TORCH_WARN_ONCE instead of TORCH_WARN
jacobhinkle Jun 28, 2023
7b1d763
Cleanup, mark upstream unused tensors dead
jacobhinkle Jun 29, 2023
bc57aa3
Add live statement tracking.
jacobhinkle Jun 29, 2023
cf24bd0
Refactor into parent class DeadCodeRemover
jacobhinkle Jun 29, 2023
4d0c656
Update comment for DeadCodeRemover
jacobhinkle Jun 29, 2023
d1532c2
Comment update
jacobhinkle Jun 29, 2023
4f83f3f
Update comment
jacobhinkle Jun 29, 2023
940dd73
Use getStmts in DeadCodeRemover::run
jacobhinkle Jun 29, 2023
5349160
Simplify PadOp handling
jacobhinkle Jun 29, 2023
267011c
Handle MmaOp
jacobhinkle Jun 30, 2023
d84f7fe
Use public inheritance in EmptyTensorRemover
jacobhinkle Jun 30, 2023
ea2374c
Create allOutputsDead method
jacobhinkle Jun 30, 2023
d7d43fc
Convert empty intermediate warning to comment
jacobhinkle Jun 30, 2023
1afed11
Revert "Convert empty intermediate warning to comment"
jacobhinkle Jun 30, 2023
b5c46ef
Convert intermediate empty check to assertion
jacobhinkle Jun 30, 2023
6d62c13
Move assert to else block.
jacobhinkle Jun 30, 2023
acdfa4c
Test empty-output reduction -> reduction over empty
jacobhinkle Jun 30, 2023
3591bfa
Move DeadCodeRemover to iter_visitor.cpp
jacobhinkle Jun 30, 2023
53d51ed
Fix constness and use all_of in allUsesDead. Check I/O in markDead
jacobhinkle Jul 3, 2023
b4c2372
Rework interface: only replaceVal& removeVal used in child classes
jacobhinkle Jul 3, 2023
4373c1b
Return vector<int64_t> in emptyAxes
jacobhinkle Jul 3, 2023
0a1d21f
Defer removal of Vals and Exprs until after traversal
jacobhinkle Jul 3, 2023
f2270e6
Merge branch 'main' into empty_branch_opt_pass
jacobhinkle Jul 3, 2023
fa7aed1
Defer modifying Fusion at all until after traversal.
jacobhinkle Jul 6, 2023
d1c4b8f
Merge remote-tracking branch 'origin/main' into empty_branch_opt_pass
jacobhinkle Jul 6, 2023
c1c386f
Merge branch 'main' into empty_branch_opt_pass
jacobhinkle Jul 6, 2023
7ed25be
Rename replaceVal as registerReplacement.
jacobhinkle Jul 6, 2023
6cecc75
Merge branch 'main' into empty_branch_opt_pass
jacobhinkle Jul 6, 2023
9a61d1e
Merge branch 'main' into empty_branch_opt_pass
jacobhinkle Jul 7, 2023
9609922
Return bool from run() and modifyFusion()
jacobhinkle Jul 7, 2023
8d45d43
Merge branch 'main' into empty_branch_opt_pass
jacobhinkle Jul 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp
${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp
${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp
${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp
)

if(BUILD_PYTHON)
Expand Down
159 changes: 159 additions & 0 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,4 +933,163 @@ std::vector<Val*> InputsOf::outputs(
return io.ordered_inputs;
}

/* DEAD CODE REMOVER */
bool DeadCodeRemover::run() {
// First we build a set of all live Statements so that we can detect dead
// branches.
for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) {
markLive(stmt);
}

// Note that StmtSort::getStmts() is also run in traverseTo. In the future,
// we could potentially refactor this so that derived classes from
// BackwardVisitor can make use of that traversal instead of repeating it.
traverseTo(fusion_, fusion_->outputs(), false);

// We do not remove Statements from the Fusion while traversing, to avoid
// dereferencing invalid pointers. Instead, we wait until this point to do the
// removal.
return modifyFusion();
}

void DeadCodeRemover::handle(Statement* stmt) {
if (isDead(stmt)) {
// We check whether stmt is dead before we dereference it, since it may
// have been removed from the Fusion.
return;
}
BackwardVisitor::handle(stmt);
}

void DeadCodeRemover::handle(Expr* expr) {
if (maybeRemoveExpr(expr)) {
// maybeRemoveExp will remove expr from the Fusion if all its uses are
// marked dead. In that case, we should not continue handling it since the
// expr pointer is invalid.
return;
}
BackwardVisitor::handle(expr);
}

void DeadCodeRemover::handle(TensorView* tv) {
if (!tv->isFusionOutput() && !tv->isFusionInput() && allUsesDead(tv)) {
if (!markDead(tv)) {
return;
}

if (tv->definition()) {
// If tv has a definition, it can only be removed by removing its
// definition
maybeRemoveExpr(tv->definition());
} else {
registerRemoval(tv);
}
return;
}
BackwardVisitor::handle(tv);
}

bool DeadCodeRemover::registerReplacement(Val* old_val, Val* new_val) {
vals_to_replace_.emplace_back(old_val, new_val);

if (old_val->isFusionInput()) {
// Skip removing Fusion inputs
return false;
}
TORCH_CHECK(
old_val->definition(),
"Found non-input ",
old_val->toString(),
" with no definition.");

// Mark old_val dead even if we can't yet remove it due to its definition
// having some live outputs
TORCH_CHECK(
markDead(old_val),
"Attempted to replace ",
old_val->toString(),
" which was previously marked dead.");

// If old_val has a definition, it can only be removed by removing its
// definition
return maybeRemoveExpr(old_val->definition());
}

bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) {
if (allOutputsDead(expr)) {
if (!markDead(expr)) {
// Expr was already marked dead, so don't try to remove it again
return false;
}

const auto outputs = expr->outputs();
for (auto outp : outputs) {
registerRemoval(outp);
}
registerRemoval(expr);
return true;
} else {
return false;
}
}

void DeadCodeRemover::registerRemoval(Val* val) {
TORCH_INTERNAL_ASSERT(
!val->isFusionInput(),
"Call to registerRemoval on Fusion input is illegal: ",
val->toString());
vals_to_remove_.push_back(val);
}

void DeadCodeRemover::markLiveRecursive(Statement* stmt) {
if (isLive(stmt)) {
return;
}
markLive(stmt);
if (stmt->isVal() && stmt->asVal()->definition()) {
markLiveRecursive(stmt);
} else {
auto expr = stmt->asExpr();
for (const auto inp : expr->outputs()) {
markLive(inp);
}
for (const auto inp : expr->inputs()) {
markLiveRecursive(inp);
}
}
}

bool DeadCodeRemover::markDead(Statement* stmt) {
return (bool)live_statements_.erase(stmt);
}

bool DeadCodeRemover::modifyFusion() const {
bool modified_fusion = false;
for (auto [old_val, new_val] : vals_to_replace_) {
if (old_val->isFusionOutput()) {
fusion_->replaceOutput(old_val, new_val);
}
for (auto use : old_val->uses()) {
ir_utils::replaceValInExpr(use, old_val, new_val);
}
modified_fusion = true;
}
for (auto val : vals_to_remove_) {
fusion_->removeVal(val);
modified_fusion = true;
}
for (auto expr : exprs_to_remove_) {
// Fusion::removeVal(val) actually removes val->definition() from the
// Fusion, and sets all its outputs' definitions to nullptr. So we should
// not need to manually remove Exprs here. Instead, we just assert that
// they have already been removed if they were registered for removal.
TORCH_INTERNAL_ASSERT(
!fusion_->inContainer(expr),
"Expression ",
expr->toString(),
" was marked for removal but has not yet been removed.");
}
return modified_fusion;
}

} // namespace nvfuser
149 changes: 146 additions & 3 deletions csrc/iter_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <ir/base_nodes.h>
#include <type.h>

#include <deque>
Expand All @@ -19,9 +20,6 @@
namespace nvfuser {

class Fusion;
class Statement;
class Expr;
class Val;

/*
* IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
Expand Down Expand Up @@ -372,4 +370,149 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor {
const std::vector<Val*>& outputs_);
};

//! This is a generic traversal class that is used to modify a Fusion graph by
//! replacing Vals to simplify computation or remove dead code. This differs
//! from OptOutMutator, which is built for mutating TensorViews in-place in a
//! graph by altering the associated IterDomains, and which does not easily
//! handle modifying TensorView definitions and Fusion outputs during traversal.
//!
//! Derived classes should override handle() for relevant Exprs and they should
//! make use of registerReplacement() to change the definitions of Vals in the
//! graph. Note that if replacements are made using registerReplacement(old_val,
//! new_val), then neither new_val nor any new Statements produced in creating
//! it will be traversed by this class. Also note that any Vals or Exprs that
//! are previously marked dead will not be processed by handle().
class DeadCodeRemover : BackwardVisitor {
public:
DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {}

DeadCodeRemover(const DeadCodeRemover& other) = default;
DeadCodeRemover& operator=(const DeadCodeRemover& other) = default;

DeadCodeRemover(DeadCodeRemover&& other) = default;
DeadCodeRemover& operator=(DeadCodeRemover&& other) = default;

//! Instead of traverseTo, run() is the entry point for this class, and we
//! always traverse from outputs backward to their inputs.
//!
//! Returns a bool indicating whether the Fusion was modified or not.
bool run();

inline Fusion* fusion() const {
return fusion_;
}

protected:
using BackwardVisitor::handle;

void handle(Statement* stmt) override;
void handle(Expr* expr) override;

//! We implement this in order to remove dangling TensorViews whose uses are
//! all dead. Note that we do not remove other ValTypes like Scalars since
//! they might still be used as attributes or members of other objects, which
//! is not reflected by Val::uses().
void handle(TensorView* tv) override;

//! Registers a Val for replacement in outputs and in all its uses.
//!
//! Note that replacement does not occur immediately, but will be done after
//! the traversal is completed. This is so that any Val* and Expr* pointers
//! may be safely dereferenced during traversal.
//!
//! The argument old_val is always marked Dead by this method. If old_val is a
//! Fusion input, we do not replace it. If old_val's definition is non-null
//! and has other outputs which are not dead, we do not remove old_val.
//!
//! Returns whether old_val was registered for removal from the Fusion.
bool registerReplacement(Val* old_val, Val* new_val);

//! Find whether a statement is not marked as live code.
inline bool isDead(Statement* stmt) const {
return live_statements_.find(stmt) == live_statements_.end();
}

//! Find whether a statement is marked as live code.
inline bool isLive(Statement* stmt) const {
return !isDead(stmt);
}

//! Check whether all outputs of an expression have been marked dead
inline bool allOutputsDead(Expr* expr) const {
return std::all_of(
expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) {
return isDead(outp);
});
}

//! Check whether all uses have been marked dead
inline bool allUsesDead(Val* val) const {
return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) {
return isDead(use);
});
}

private:
//! Removes an Expr* from the Fusion, if possible.
//!
//! The Expr will _only_ be marked dead and removed if all of its outputs are
//! already marked dead. In this case all the outputs will also be removed
//! from the Fusion.
//!
//! Returns whether the Expr was marked dead and removed from the Fusion.
bool maybeRemoveExpr(Expr* expr);

//! Mark a single Statement as being alive.
inline void markLive(Statement* stmt) {
live_statements_.insert(stmt);
}

//! Ensure that a Statement and its upstream Statements are alive. If it is an
//! Expr, ensure all its inputs are alive. If it's a Val with a definition,
//! recursive to the definition. Newly-created Statements default to being
//! dead, so this method is called when adding a Statement to the active path
//! of the Fusion inside registerReplacement.
void markLiveRecursive(Statement* stmt);

//! Mark a single Statement as being dead. This does not remove stmt from the
//! Fusion. It is an error to call this on a Fusion output.
//!
//! Returns true if the statement was previously live, and false otherwise.
bool markDead(Statement* stmt);

//! Register a Val for later removal.
void registerRemoval(Val* val);

//! Register an Expr for later removal.
//!
//! Note that if any of its outputs are removed, expr will be removed even if
//! it is not marked for removal, and all its outputs will have their
//! definitions set to nullptr.
inline void registerRemoval(Expr* expr) {
exprs_to_remove_.push_back(expr);
}

//! All modifications to the Fusion are registered during traversal then
//! later they are committed by this method. For safety, this should only be
//! run after traversing the graph.
//!
//! Returns a bool indicating whether any modifications were performed.
bool modifyFusion() const;

private:
//! The Fusion associated with live_statements_
Fusion* fusion_;

//! Statements are marked dead by removing them from this set
std::unordered_set<Statement*> live_statements_;

//! Vals to be replaced in outputs and with replaceValInExpr in all uses.
std::vector<std::pair<Val*, Val*>> vals_to_replace_;

//! Statements that will be removed. We remove Vals before Exprs, so we track
//! them separately here.
std::vector<Val*> vals_to_remove_;
std::vector<Expr*> exprs_to_remove_;
};

} // namespace nvfuser
6 changes: 6 additions & 0 deletions csrc/kernel_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,12 @@ FusionKernelRuntime::FusionKernelRuntime(
optimization::OptimizationPass<optimization::PreSegmenter>::runPass(
fusion.get());

if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) {
std::cout << "Fusion IR after pre-segmenter optimization passes:"
<< std::endl;
fusion->printMath();
}
Comment on lines +676 to +680
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New dump option fusion_ir_preseg to more easily monitor what the optimization passes are doing.


all_tvs_ = ir_utils::allTvs(fusion.get());

// Run segmentation on the copied fusion
Expand Down
3 changes: 3 additions & 0 deletions csrc/optimization/pre_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

#include <optimization/add_axioms.h>
#include <optimization/consecutive_cast.h>
#include <optimization/remove_empty.h>

namespace nvfuser::optimization {

void PreSegmenter::runPass(Fusion* fusion) {
// Replace TensorViews with zero extent. Outputs and inputs may still be empty
OptimizationPass<RemoveEmptyPass>::runPass(fusion);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I placed this pass first since I assumed we may want to do any DCE passes before other patterns are matched.

// removes consecutive cast operations
OptimizationPass<ConsecutiveCastPass>::runPass(fusion);
OptimizationPass<AddAxiomsPass>::runPass(fusion);
Expand Down
Loading