19 changes: 13 additions & 6 deletions include/tvm/relax/analysis.h
@@ -399,18 +399,25 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
/*!
* \brief Get the use-def chain of variables inside a function.
*
* \param fn The function to be analyzed.
* \return A map from variable definitions to a set of uses and variables needed by return value.
* \param expr The expression to be analyzed.
*
* \return A tuple of variable usage and variable outputs. The first
* element is a map from variable definitions to the set of downstream
* users of that definition. The second element is a list of
* variables whose usage occurs outside of any variable binding,
* typically the output body of a relax::Function or a relax::SeqExpr.
*/
std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);

/*!
* \brief Remove unused statements inside DataflowBlocks.
*
* \param fn The function to remove unused statements.
* \return The function that contains no unused statements in DataflowBlock.
* \param expr The expression (typically a relax::Function) from which
* to remove unused statements.
*
* \return The updated expression, with no unused statements remaining in any DataflowBlock.
*/
TVM_DLL Function RemoveAllUnused(const Function fn);
TVM_DLL Expr RemoveAllUnused(Expr expr);

/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
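With these signatures widened from `Function` to `Expr`, the analysis can run on a bare `SeqExpr` as well as on a whole function. A minimal sketch of the corresponding Python entry point (assuming the `remove_all_unused` binding exercised in the tests below; the function body is illustrative):

```python
import tvm
from tvm.script import relax as R
from tvm.relax.analysis import remove_all_unused

@R.function(private=True)
def func(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
    unused = R.add(x, x)  # pure and unused, so it is safe to drop
    return x

# Accepts a relax::Function, as before...
after_func = remove_all_unused(func)
# ...and, with this change, any relax::Expr, such as the function's SeqExpr body.
after_body = remove_all_unused(func.body)
```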
7 changes: 4 additions & 3 deletions src/relax/analysis/udchain.cc
@@ -40,14 +40,15 @@ class UDChain : public relax::ExprVisitor {
// A nullptr user indicates that the variable is part of the function's output.
std::map<const VarNode*, std::set<const VarNode*>> to_users;

const VarNode* cur_user_;
const VarNode* cur_user_{nullptr};

void VisitBinding_(const VarBindingNode* binding) override {
// Save the current user so it can be restored after visiting this binding.
auto cache = cur_user_;
cur_user_ = binding->var.get();
this->VisitVarDef(binding->var);
this->VisitExpr(binding->value);
cur_user_ = nullptr;
cur_user_ = cache;
}

void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); }
@@ -63,7 +64,7 @@ class UDChain : public relax::ExprVisitor {
};

std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
const Function& fn) {
const Expr& fn) {
UDChain udchain;
udchain.VisitExpr(fn);

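The `cur_user_` save/restore matters when bindings nest: a `VarBinding` whose value is itself a `Function` contains inner bindings, and resetting `cur_user_` to `nullptr` after visiting one would mis-attribute uses visited afterwards. A hedged sketch of the shape that motivates the fix (assuming local Relax functions, which TVMScript supports; all names are illustrative):

```python
from tvm.script import relax as R

@R.function(private=True)
def outer(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
    # Visiting the binding of `inner` recurses into its body; restoring the
    # cached user afterwards keeps later uses correctly attributed instead of
    # being marked as function outputs (the nullptr user).
    @R.function
    def inner(y: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        z = R.add(y, y)
        return z

    w = inner(x)
    return w
```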
51 changes: 22 additions & 29 deletions src/relax/ir/binding_rewrite.cc
@@ -245,38 +245,31 @@ class RemoveUnusedVars : public ExprMutator {
RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
: RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {}

BindingBlock VisitBindingBlock_(const BindingBlockNode* block) override {
builder_->BeginBindingBlock();
for (Binding binding : block->bindings) {
bool can_remove = [&]() -> bool {
if (!unused_vars.count(binding->var)) {
return false;
}
auto var_binding = binding.as<VarBindingNode>();
if (!var_binding) {
return false;
}
return var_binding->value->IsInstance<FunctionNode>();
}();
if (!can_remove) {
VisitBinding(binding);
}
void VisitBinding_(const VarBindingNode* binding) override {
bool can_remove = unused_vars.count(binding->var) &&
(in_dataflow_block_ || !ContainsImpureCall(binding->value));
if (!can_remove) {
ExprMutator::VisitBinding_(binding);
}
return builder_->EndBlock();
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
auto prev_dfb = GetRef<DataflowBlock>(block);
builder_->BeginDataflowBlock();
for (Binding binding : block->bindings) {
if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
VisitBinding(binding);
}
bool capture_output = (block == caught_rewrite.get());

bool cache = in_dataflow_block_;
in_dataflow_block_ = true;
BindingBlock output = ExprMutator::VisitBindingBlock_(block);
in_dataflow_block_ = cache;

if (capture_output) {
caught_rewrite = Downcast<DataflowBlock>(output);
}
auto new_dfb = builder_->EndBlock();
if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb);
return std::move(new_dfb);

return std::move(output);
}

private:
bool in_dataflow_block_{false};
};

void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
@@ -327,10 +320,10 @@ void DataflowBlockRewriteNode::RemoveAllUnused() {
TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
.set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); });

Function RemoveAllUnused(Function fn) {
auto [users, outputs] = FunctionUseDef(fn);
Expr RemoveAllUnused(Expr expr) {
auto [users, outputs] = FunctionUseDef(expr);
RemoveUnusedVars remover(users, outputs);
return Downcast<Function>(remover.VisitExpr_(fn.get()));
return remover.VisitExpr(std::move(expr));
}

TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
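The removal rule in `VisitBinding_` is deliberately asymmetric: inside a dataflow block every unused binding may be dropped, while outside one a binding must additionally be free of impure calls. A sketch of the distinction, using the same constructs as the tests below (the packed-function names are illustrative):

```python
from tvm.script import relax as R

@R.function(private=True)
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
    # Removable when unused: R.call_dps_packed is treated as pure.
    pure_unused = R.call_dps_packed("my_pure_func", (x,), R.Tensor((32, 32), "float32"))
    # Kept even when unused: R.call_packed may carry side effects.
    impure_unused = R.call_packed(
        "my_impure_func", x, sinfo_args=R.Tensor((32, 32), "float32")
    )
    return x
```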
2 changes: 1 addition & 1 deletion src/relax/ir/dataflow_matcher.cc
@@ -938,7 +938,7 @@ class PatternRewriter : ExprMutator {
params.insert(p.get());
}
PatternRewriter rewriter(pat, rewriter_func, params);
return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
}

void VisitBinding_(const VarBindingNode* binding) final {
2 changes: 1 addition & 1 deletion src/relax/transform/dead_code_elimination.cc
@@ -137,7 +137,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> ent
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
auto new_func = RemoveAllUnused(opt.value());
auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
if (!new_func.same_as(base_func)) {
updates->Add(gvar, new_func);
}
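For pass users nothing changes at the module level; `DeadCodeElimination` still rewrites each `relax::Function`, now downcasting the `Expr` returned by `RemoveAllUnused`. A minimal usage sketch:

```python
import tvm
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
        with R.dataflow():
            unused = R.add(x, x)  # dropped by the pass
            y = R.multiply(x, x)
            R.output(y)
        return y

# Runs RemoveAllUnused over every relax::Function in the module.
mod = tvm.relax.transform.DeadCodeElimination()(Module)
```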
2 changes: 1 addition & 1 deletion src/relax/transform/fold_constant.cc
@@ -34,7 +34,7 @@ class ConstantFolder : public ExprMutator {
public:
static Function Fold(Function func, IRModule ctx_module) {
ConstantFolder folder(std::move(ctx_module));
func = RemoveAllUnused(Downcast<Function>(folder(func)));
func = Downcast<Function>(RemoveAllUnused(folder(func)));
return func;
}

2 changes: 1 addition & 1 deletion src/relax/transform/gradient.cc
@@ -635,7 +635,7 @@ class GradientMutator : private ExprMutator {
new_func = CallTIRWithGradEliminator::Transform(new_func);

if (remove_all_unused) {
new_func = RemoveAllUnused(new_func);
new_func = Downcast<Function>(RemoveAllUnused(new_func));
}

// Step 5.3 mark the transformed function as public
126 changes: 111 additions & 15 deletions tests/python/relax/test_analysis.py
@@ -88,6 +88,13 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:


def test_binding_block_remove_all_unused():
"""Remove unused dataflow bindings

Removal of unused bindings may not remove side effects. Since
bindings within a dataflow block are guaranteed not to have side
effects, they may be removed if unused.
"""

@tvm.script.ir_module
class IdentityUnused:
@R.function
@@ -117,24 +124,49 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])


def test_binding_block_remove_all_unused_without_dataflow():
@tvm.script.ir_module
class IdentityUnused:
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed(
"my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
)
z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
return z
def test_binding_block_remove_unused_pure_without_dataflow():
"""Remove unused dataflow bindings

optimized = remove_all_unused(IdentityUnused["main"])
Removal of unused bindings must not remove side effects. An unused
binding whose value is a pure operation (e.g. `R.call_dps_packed`)
may be removed, even when it occurs outside of a dataflow block.
"""

GroundTruth = IdentityUnused
@R.function(private=True)
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32"))
return x

tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
@R.function(private=True)
def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
return x

after = remove_all_unused(before)
tvm.ir.assert_structural_equal(expected, after)


def test_binding_block_keep_impure_without_dataflow():
"""Remove unused dataflow bindings

Removal of unused bindings may not remove side effects. Unused
bindings whose value is an impure operation (e.g. `R.call_packed`)
may not be removed, as outside of a dataflow block they may
contain side effects.
"""

@R.function(private=True)
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
return y

expected = before

after = remove_all_unused(before)
tvm.ir.assert_structural_equal(expected, after)


def test_binding_block_remove_all_unused_func_without_dataflow():
@@ -226,6 +258,70 @@ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3):
tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])


def test_remove_all_unused_from_dataflow_block():
"""Like test_chained_remove_all_unused, but on a SeqExpr"""

@R.function
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed(
"my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
)
R.output(lv0)
return lv0

@R.function
def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
R.output(lv0)
return lv0

after = remove_all_unused(before.body)
tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)


def test_remove_all_unused_from_binding_block():
"""Like test_chained_remove_all_unused, but on a SeqExpr"""

@R.function
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32"))
return lv0

@R.function
def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
return lv0

after = remove_all_unused(before.body)
tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)


def test_retain_impure_calls_unused_in_binding_block():
"""An impure call may have side effects, and must be kept"""

@R.function
def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32"))
return lv0

@R.function
def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
lv0 = x
unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32"))
return lv0

after = remove_all_unused(before.body)
tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)


def test_name_to_binding_var_shadowing():
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
15 changes: 3 additions & 12 deletions tests/python/relax/test_transform_fold_constant.py
@@ -73,7 +73,6 @@ def before(c0: R.Tensor((16, 16), "float32")):

@R.function
def expected(c1: R.Tensor((16, 16), "float32")):
lv0 = c1
return c1

c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
@@ -104,7 +103,6 @@ def before(c0: R.Tensor((2, 3), "float32")):

@R.function
def expected(c1: R.Tensor((3, 2), "float32")):
lv0 = c1
return c1

c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
@@ -135,8 +133,6 @@ def before(c0: R.Tensor((2, 2), "float32")):

@R.function
def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")):
lv0 = c1
lv1 = c2
return c2

c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
@@ -218,27 +214,23 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)):
lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
# this line can not be folded because x's shape is unknown
lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
return lv3
return (lv0, lv3)

@R.function
def expected(
c0: R.Tensor((16, 16), "float32"),
c1: R.Tensor((16, 16), "float32"),
c2: R.Tensor((16, 16), "float32"),
x: R.Tensor("float32", ndim=2),
) -> R.Tensor:
):
n, m = T.int64(), T.int64()
cls = Module
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
# this line cannot be folded because n is unknown
lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
# this line can be folded
lv1 = c1
# this line can be folded because all inputs are const
lv2 = c2
# this line can not be folded because x's shape is unknown
lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32"))
return lv3
return (lv0, lv3)

c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
c1_np = c0_np + 1
@@ -268,7 +260,6 @@ def before(c0: R.Tensor((16, 16), "int32")):

@R.function
def expected(c1: R.Tensor((16, 16), "int32")):
lv0 = c1
return c1

c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
1 change: 0 additions & 1 deletion tests/python/relax/test_tuning_api.py
@@ -64,7 +64,6 @@ def before(c0: R.Tensor((16, 16), "int32")):
# Expected IRModule after transformation.
@R.function
def expected(c1: R.Tensor((16, 16), "int32")):
lv0 = c1
return c1

