From 4d704a4ca5a42f00f527fc4bd8936aacce57ef1f Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 17 Jan 2017 19:24:16 -0800 Subject: [PATCH 1/2] [PASS] Assign unique names to variables in ConvertSSA pass --- include/tvm/ir_pass.h | 3 ++- src/c_api/c_api_pass.cc | 18 +++++++------- src/pass/ssa.cc | 43 +++++++++++++++++++++------------ tests/python/test_pass_basic.py | 8 ++++-- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index a45bbbb91fd8..1bf9323b46ae 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -57,7 +57,8 @@ Stmt ScheduleOps(Schedule s, Map dom_map); bool VerifySSA(const Stmt& ir); /*! - * \brief Convert a IR node to be SSA form. + * \brief Convert a IR node to be SSA form and assign a unique name + * to each variable. * \param stmt The source statement to be converted. * \return The converted form. */ diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index 10ffe95f653d..e45e25a265d0 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -15,22 +15,22 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - if (dynamic_cast(args.at(0).sptr.get())) { - *ret = Simplify(args.at(0).operator Expr()); - } else { + if (dynamic_cast(args.at(0).sptr.get())) { *ret = Simplify(args.at(0).operator Stmt()); + } else { + *ret = Simplify(args.at(0).operator Expr()); } }); TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - CHECK(args.at(1).type_id == kNodeHandle); - if (dynamic_cast(args.at(0).sptr.get())) { - *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); - } else { + if (dynamic_cast(args.at(0).sptr.get())) { + CHECK(args.at(1).type_id == kNodeHandle); *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + } else { + Expr a = args.at(0).operator Expr(); + Expr b = args.at(1).operator Expr(); + *ret = Equal(a, b); } }); diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 9b01e24fee38..08e7348fc5e8 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -94,23 +94,29 @@ class IRConvertSSA : public IRMutator { static auto& fset_var_def = FSetVarDef::vtable_expr(); if (fget_var_def.can_dispatch(expr)) { VarExpr v = fget_var_def(expr); - VarExpr new_var = v; + std::string old_name = v->name_hint; + // Sanity check for Allocate if (defined_.count(v.get()) != 0) { CHECK(expr.as() == nullptr) << "One allocation in two places, cannot rename buffer in allocate"; - new_var = Variable::make(v->type, v->name_hint); } else { defined_.insert(v.get()); } + // assign a unique name to variable + std::stringstream new_name; + if (name_index_.count(old_name) != 0) { + new_name << old_name << name_index_[old_name]; + name_index_[old_name] += 1; + } else { + new_name << old_name << "0"; + name_index_[old_name] = 1; + } + VarExpr new_var = Variable::make(v->type, new_name.str()); scope_[v.get()].push_back(new_var); Expr new_expr = IRMutator::Mutate(expr); scope_[v.get()].pop_back(); - if (!new_var.same_as(v)) { - return fset_var_def(new_expr, new_var); - } else { - return new_expr; - } + return fset_var_def(new_expr, new_var); } else if (expr.as()) { const Variable* v = expr.as(); if (scope_.count(v) != 0) { @@ -129,21 +135,25 @@ class IRConvertSSA : public IRMutator { static auto& fset_var_def = FSetVarDef::vtable_stmt(); if (fget_var_def.can_dispatch(stmt)) { VarExpr v = fget_var_def(stmt); - VarExpr new_var = v; - if (defined_.count(v.get()) != 0) { - new_var = Variable::make(v->type, v->name_hint); - } else { + std::string old_name = v->name_hint; + if (defined_.count(v.get()) == 0) { defined_.insert(v.get()); } + // assign a unique name to variable + std::stringstream new_name; + if (name_index_.count(old_name) != 0) { + new_name << old_name << name_index_[old_name]; + name_index_[old_name] += 1; + } else { + new_name << old_name << "0"; + name_index_[old_name] = 1; + } + VarExpr new_var = Variable::make(v->type, new_name.str()); scope_[v.get()].push_back(new_var); Stmt new_stmt = IRMutator::Mutate(stmt); scope_[v.get()].pop_back(); - if (!new_var.same_as(v)) { - return fset_var_def(new_stmt, new_var); - } else { - return new_stmt; - } + return fset_var_def(new_stmt, new_var); } else { return IRMutator::Mutate(stmt); } @@ -152,6 +162,7 @@ class IRConvertSSA : public IRMutator { private: std::unordered_map > scope_; std::unordered_set defined_; + std::unordered_map name_index_; }; } // namespace diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index ebffc58805f3..b9e8d501e68a 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -8,6 +8,9 @@ def test_simplify(): assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) + let = tvm.make.Let(x, 1, x + 3) + e4 = tvm.ir_pass.Simplify(let) + assert(tvm.ir_pass.Equal(e4, 4)) def test_verify_ssa(): @@ -20,8 +23,9 @@ def test_verify_ssa(): def test_convert_ssa(): x = tvm.Var('x') y = tvm.Var() - let = tvm.make.Let(x, 1, x + 1) - z = tvm.make.Evaluate(let + let) + let1 = tvm.make.Let(x, 1, x + 1) + let2 = tvm.make.Let(x, 1, x + y) + z = tvm.make.Evaluate(let1 + let2) assert(not tvm.ir_pass.VerifySSA(z)) z_ssa = tvm.ir_pass.ConvertSSA(z) assert(tvm.ir_pass.VerifySSA(z_ssa)) From aa8edec20b3aac23fe0f5648b738603855cee292 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 18 Jan 2017 14:45:47 -0800 Subject: [PATCH 2/2] revert change to ConverSSA pass --- include/tvm/ir_pass.h | 3 +-- src/pass/ssa.cc | 43 ++++++++++++++++--------------------------- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 1bf9323b46ae..a45bbbb91fd8 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -57,8 +57,7 @@ Stmt ScheduleOps(Schedule s, Map dom_map); bool VerifySSA(const Stmt& ir); /*! - * \brief Convert a IR node to be SSA form and assign a unique name - * to each variable. + * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. * \return The converted form. */ diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 08e7348fc5e8..9b01e24fee38 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -94,29 +94,23 @@ class IRConvertSSA : public IRMutator { static auto& fset_var_def = FSetVarDef::vtable_expr(); if (fget_var_def.can_dispatch(expr)) { VarExpr v = fget_var_def(expr); - std::string old_name = v->name_hint; - // Sanity check for Allocate + VarExpr new_var = v; if (defined_.count(v.get()) != 0) { CHECK(expr.as() == nullptr) << "One allocation in two places, cannot rename buffer in allocate"; + new_var = Variable::make(v->type, v->name_hint); } else { defined_.insert(v.get()); } - // assign a unique name to variable - std::stringstream new_name; - if (name_index_.count(old_name) != 0) { - new_name << old_name << name_index_[old_name]; - name_index_[old_name] += 1; - } else { - new_name << old_name << "0"; - name_index_[old_name] = 1; - } - VarExpr new_var = Variable::make(v->type, new_name.str()); scope_[v.get()].push_back(new_var); Expr new_expr = IRMutator::Mutate(expr); scope_[v.get()].pop_back(); - return fset_var_def(new_expr, new_var); + if (!new_var.same_as(v)) { + return fset_var_def(new_expr, new_var); + } else { + return new_expr; + } } else if (expr.as()) { const Variable* v = expr.as(); if (scope_.count(v) != 0) { @@ -135,25 +129,21 @@ class IRConvertSSA : public IRMutator { static auto& fset_var_def = FSetVarDef::vtable_stmt(); if (fget_var_def.can_dispatch(stmt)) { VarExpr v = fget_var_def(stmt); - std::string old_name = v->name_hint; - if (defined_.count(v.get()) == 0) { - defined_.insert(v.get()); - } - // assign a unique name to variable - std::stringstream new_name; - if (name_index_.count(old_name) != 0) { - new_name << old_name << name_index_[old_name]; - name_index_[old_name] += 1; + VarExpr new_var = v; + if (defined_.count(v.get()) != 0) { + new_var = Variable::make(v->type, v->name_hint); } else { - new_name << old_name << "0"; - name_index_[old_name] = 1; + defined_.insert(v.get()); } - VarExpr new_var = Variable::make(v->type, new_name.str()); scope_[v.get()].push_back(new_var); Stmt new_stmt = IRMutator::Mutate(stmt); scope_[v.get()].pop_back(); - return fset_var_def(new_stmt, new_var); + if (!new_var.same_as(v)) { + return fset_var_def(new_stmt, new_var); + } else { + return new_stmt; + } } else { return IRMutator::Mutate(stmt); } @@ -162,7 +152,6 @@ class IRConvertSSA : public IRMutator { private: std::unordered_map > scope_; std::unordered_set defined_; - std::unordered_map name_index_; }; } // namespace