From d15bbf6f34876e018f5f9d6c547906611af7fb3d Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 8 Feb 2017 11:08:02 +0800 Subject: [PATCH 1/8] [FUSION] add Fusion(Schedule) --- include/tvm/schedule_pass.h | 2 + python/tvm/schedule.py | 2 +- src/api/api_schedule.cc | 1 + src/pass/ir_util.h | 2 + src/schedule/fusion.cc | 85 +++++++++++++++++++ .../unittest/test_schedule_schedule_ops.py | 16 ++++ 6 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 src/schedule/fusion.cc diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 57e442c5c15e..845ec893a547 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -33,6 +33,8 @@ Map InferBound(Schedule sch); */ Stmt ScheduleOps(Schedule s, Map dom_map); +Schedule Fusion(Schedule sch); + } // namespace schedule } // namespace tvm #endif // TVM_SCHEDULE_PASS_H_ diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 3fd7f9730d46..fee0fb3b1274 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -135,7 +135,7 @@ def compute_root(self): parent : Stage The parent stage """ - _api_internal._StageComputeInline(self) + _api_internal._StageComputeRoot(self) def reorder(self, *args): """reorder the arguments in the specified order. diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index a4462117d494..03ca292b44a6 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -28,6 +28,7 @@ namespace schedule { REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(CreateReadGraph); +REGISTER_SCHEDULE_PASS1(Fusion); REGISTER_SCHEDULE_PASS2(PostDFSOrder); REGISTER_SCHEDULE_PASS2(ScheduleOps); diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 2fbff80995f6..88cae00a3777 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -70,6 +70,8 @@ inline Stmt MergeNest(std::vector > nest, Stmt body) { return body; } +bool IsEwise(Expr e, std::vector axis); + } // namespace ir } // namespace tvm #endif // TVM_PASS_IR_UTIL_H_ diff --git a/src/schedule/fusion.cc b/src/schedule/fusion.cc new file mode 100644 index 000000000000..e263a702be3a --- /dev/null +++ b/src/schedule/fusion.cc @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file schedule.cc + */ +#include +#include +#include "./graph.h" + +namespace tvm { + +namespace ir { + +static bool check_index(std::vector axis, Array index) { + if (axis.size() != index.size()) + return false; + + for (size_t i = 0; i < axis.size(); ++i) { + const Variable *v1 = axis[i].as(); + const Variable *v2 = index[i].as(); + if (!(v1 && v2) || (v1 != v2)) + return false; + } + return true; +} + +template +static bool check_binary_op(const T *n, std::vector axis) { + const Call *ac = n->a.template as(); + const Call *bc = n->b.template as(); + if (!(ac && bc)) + return false; + return (check_index(axis, ac->args) && check_index(axis, bc->args)); +} + +bool IsEwise(Expr e, std::vector axis) { + if (const Add *n = e.as()) { + return check_binary_op(n, axis); + } else if (const Sub *n = e.as()) { + return check_binary_op(n, axis); + } else if (const Mul *n = e.as()) { + return check_binary_op(n, axis); + } else if (const Div *n = e.as
()) { + return check_binary_op(n, axis); + } else if (const Mod *n = e.as()) { + return check_binary_op(n, axis); + } else if (const Min *n = e.as()) { + return check_binary_op(n, axis); + } else if (const Max *n = e.as()) { + return check_binary_op(n, axis); + } + return false; +} + +} // namespace ir + + +namespace schedule { + +Schedule Fusion(Schedule sch) { + auto g = schedule::CreateReadGraph(sch->roots); + Array post_order = schedule::PostDFSOrder(sch->roots, g); + for (Operation op : post_order) { + if (const ComputeOpNode* compute = op.as()) { + std::vector axis; + for (const auto& iter : compute->axis) { + axis.push_back(iter->var); + } + if (ir::IsEwise(compute->body, axis)) { + bool is_root = false; + for (auto r : sch->roots) { + if (r == op) { + is_root = true; + break; + } + } + if (!is_root) + sch[op].compute_inline(); + } + } + } + return sch; +} + +} // namespace schedule +} // namespace tvm diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index feed951e295f..bbc9685ac0b3 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -42,8 +42,24 @@ def test_schedule2(): stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) +def test_fusion(): + m = tvm.Var('m') + n = tvm.Var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + C = tvm.placeholder((m, n), name='C') + T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1') + T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2') + + s = tvm.Schedule(T2.op) + fs = tvm.schedule.Fusion(s) + bounds = tvm.schedule.InferBound(fs) + stmt = tvm.schedule.ScheduleOps(fs, bounds) + print(stmt) + if __name__ == "__main__": test_schedule0() test_schedule1() test_schedule2() + test_fusion() From 83f5c4f0a407a958a2b7477325ee0d1a336ed076 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 8 Feb 2017 15:06:02 +0800 Subject: [PATCH 2/8] [FUSION] rename to AutoFuseEwise, detect whether the stage has been scheduled --- include/tvm/schedule_pass.h | 2 +- src/api/api_schedule.cc | 6 +++++- src/schedule/fusion.cc | 9 ++++++--- tests/python/unittest/test_schedule_schedule_ops.py | 6 +++--- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 845ec893a547..f1d74ef11e28 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -33,7 +33,7 @@ Map InferBound(Schedule sch); */ Stmt ScheduleOps(Schedule s, Map dom_map); -Schedule Fusion(Schedule sch); +void AutoFuseEwise(Schedule sch); } // namespace schedule } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 03ca292b44a6..0dd1fd4a7878 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -13,6 +13,11 @@ namespace tvm { namespace schedule { +TVM_REGISTER_API(_schedule_AutoFuseEwise) +.set_body([](TVMArgs args, TVMRetValue* ret) { + AutoFuseEwise(args[0]); + }); + #define REGISTER_SCHEDULE_PASS1(PassName) \ TVM_REGISTER_API(_schedule_## PassName) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ @@ -28,7 +33,6 @@ namespace schedule { REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(CreateReadGraph); -REGISTER_SCHEDULE_PASS1(Fusion); REGISTER_SCHEDULE_PASS2(PostDFSOrder); REGISTER_SCHEDULE_PASS2(ScheduleOps); diff --git a/src/schedule/fusion.cc b/src/schedule/fusion.cc index e263a702be3a..0ce97e9739ac 100644 --- a/src/schedule/fusion.cc +++ b/src/schedule/fusion.cc @@ -56,7 +56,11 @@ bool IsEwise(Expr e, std::vector axis) { namespace schedule { -Schedule Fusion(Schedule sch) { +static bool is_stage_scheduled(const Stage& s) { + return !(s->relations.empty() && s->attach_type == kNone); +} + +void AutoFuseEwise(Schedule sch) { auto g = schedule::CreateReadGraph(sch->roots); Array post_order = schedule::PostDFSOrder(sch->roots, g); for (Operation op : post_order) { @@ -65,7 +69,7 @@ Schedule Fusion(Schedule sch) { for (const auto& iter : compute->axis) { axis.push_back(iter->var); } - if (ir::IsEwise(compute->body, axis)) { + if (!is_stage_scheduled(sch[op]) && ir::IsEwise(compute->body, axis)) { bool is_root = false; for (auto r : sch->roots) { if (r == op) { @@ -78,7 +82,6 @@ Schedule Fusion(Schedule sch) { } } } - return sch; } } // namespace schedule diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index bbc9685ac0b3..4655d17f5e6f 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -52,9 +52,9 @@ def test_fusion(): T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2') s = tvm.Schedule(T2.op) - fs = tvm.schedule.Fusion(s) - bounds = tvm.schedule.InferBound(fs) - stmt = tvm.schedule.ScheduleOps(fs, bounds) + tvm.schedule.AutoFuseEwise(s) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) From f66b3c98539a645fce6cffb446b72ad3a30923e7 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 03:06:10 +0000 Subject: [PATCH 3/8] [FUSION] change to visitor pattern --- include/tvm/ir_pass.h | 6 ++ include/tvm/schedule_pass.h | 7 +- src/api/api_schedule.cc | 4 +- src/pass/elem_wise_detector.cc | 57 +++++++++++++ src/pass/ir_util.h | 2 - src/schedule/fusion.cc | 82 ++++--------------- .../unittest/test_schedule_schedule_ops.py | 2 +- 7 files changed, 86 insertions(+), 74 deletions(-) create mode 100644 src/pass/elem_wise_detector.cc diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index b11486d9023a..10ed591a226a 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -167,6 +167,12 @@ Array SplitHostDevice(LoweredFunc func); */ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); +/*! + * \brief Whether the node is element-wise. + * \return whether the node is element-wise. + */ +bool IsElemWise(const NodeRef& node); + } // namespace ir } // namespace tvm diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index f1d74ef11e28..0506fee52662 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -33,7 +33,12 @@ Map InferBound(Schedule sch); */ Stmt ScheduleOps(Schedule s, Map dom_map); -void AutoFuseEwise(Schedule sch); +/*! + * \brief To automatically fuse the element-wise operations. + * + * \param s The schedule to be fused. + */ +void AutoFuseElemWise(Schedule sch); } // namespace schedule } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 0dd1fd4a7878..80284824d443 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -13,9 +13,9 @@ namespace tvm { namespace schedule { -TVM_REGISTER_API(_schedule_AutoFuseEwise) +TVM_REGISTER_API(_schedule_AutoFuseElemWise) .set_body([](TVMArgs args, TVMRetValue* ret) { - AutoFuseEwise(args[0]); + AutoFuseElemWise(args[0]); }); #define REGISTER_SCHEDULE_PASS1(PassName) \ diff --git a/src/pass/elem_wise_detector.cc b/src/pass/elem_wise_detector.cc new file mode 100644 index 000000000000..dd8df2454ba7 --- /dev/null +++ b/src/pass/elem_wise_detector.cc @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file elem_wise_detector.cc + */ +#include +#include +#include + +namespace tvm { +namespace ir { + +class ElemWiseDetector : public IRVisitor { + public: + explicit ElemWiseDetector(Array axis) : axis_(axis) {} + + void Visit(const NodeRef& e) final { + if (!is_elem_wise_) + return; + IRVisitor::Visit(e); + } + + void Visit_(const Call* op) final { + Array axis = op->args; + if (axis_.size() != axis.size()) { + is_elem_wise_ = false; + return; + } + + for (size_t i = 0; i < axis_.size(); ++i) { + const Variable *v1 = axis_[i]->var.as(); + const Variable *v2 = axis[i].as(); + if (!(v1 && v2) || (v1 != v2)) { + is_elem_wise_ = false; + return; + } + } + IRVisitor::Visit_(op); + } + + bool is_elem_wise_{true}; + + private: + Array axis_; +}; + + +bool IsElemWise(const NodeRef& node) { + if (const ComputeOpNode* compute = node.as()) { + ElemWiseDetector v = ElemWiseDetector(compute->axis); + v.Visit(compute->body); + return v.is_elem_wise_; + } + return false; +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 88cae00a3777..2fbff80995f6 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -70,8 +70,6 @@ inline Stmt MergeNest(std::vector > nest, Stmt body) { return body; } -bool IsEwise(Expr e, std::vector axis); - } // namespace ir } // namespace tvm #endif // TVM_PASS_IR_UTIL_H_ diff --git a/src/schedule/fusion.cc b/src/schedule/fusion.cc index 0ce97e9739ac..22a5ff69b27e 100644 --- a/src/schedule/fusion.cc +++ b/src/schedule/fusion.cc @@ -3,83 +3,29 @@ * \file schedule.cc */ #include -#include -#include "./graph.h" +#include namespace tvm { - -namespace ir { - -static bool check_index(std::vector axis, Array index) { - if (axis.size() != index.size()) - return false; - - for (size_t i = 0; i < axis.size(); ++i) { - const Variable *v1 = axis[i].as(); - const Variable *v2 = index[i].as(); - if (!(v1 && v2) || (v1 != v2)) - return false; - } - return true; -} - -template -static bool check_binary_op(const T *n, std::vector axis) { - const Call *ac = n->a.template as(); - const Call *bc = n->b.template as(); - if (!(ac && bc)) - return false; - return (check_index(axis, ac->args) && check_index(axis, bc->args)); -} - -bool IsEwise(Expr e, std::vector axis) { - if (const Add *n = e.as()) { - return check_binary_op(n, axis); - } else if (const Sub *n = e.as()) { - return check_binary_op(n, axis); - } else if (const Mul *n = e.as()) { - return check_binary_op(n, axis); - } else if (const Div *n = e.as
()) { - return check_binary_op(n, axis); - } else if (const Mod *n = e.as()) { - return check_binary_op(n, axis); - } else if (const Min *n = e.as()) { - return check_binary_op(n, axis); - } else if (const Max *n = e.as()) { - return check_binary_op(n, axis); - } - return false; -} - -} // namespace ir - - namespace schedule { -static bool is_stage_scheduled(const Stage& s) { +namespace { +inline bool is_stage_scheduled(const Stage& s) { return !(s->relations.empty() && s->attach_type == kNone); } +} -void AutoFuseEwise(Schedule sch) { - auto g = schedule::CreateReadGraph(sch->roots); - Array post_order = schedule::PostDFSOrder(sch->roots, g); - for (Operation op : post_order) { - if (const ComputeOpNode* compute = op.as()) { - std::vector axis; - for (const auto& iter : compute->axis) { - axis.push_back(iter->var); - } - if (!is_stage_scheduled(sch[op]) && ir::IsEwise(compute->body, axis)) { - bool is_root = false; - for (auto r : sch->roots) { - if (r == op) { - is_root = true; - break; - } +void AutoFuseElemWise(Schedule sch) { + for (Stage s : sch->stages) { + if (!is_stage_scheduled(s) && ir::IsElemWise(s->op)) { + bool is_root = false; + for (auto r : sch->roots) { + if (r == s->op) { + is_root = true; + break; } - if (!is_root) - sch[op].compute_inline(); } + if (!is_root) + s.compute_inline(); } } } diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 4655d17f5e6f..3c0ee43953b8 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -52,7 +52,7 @@ def test_fusion(): T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2') s = tvm.Schedule(T2.op) - tvm.schedule.AutoFuseEwise(s) + tvm.schedule.AutoFuseElemWise(s) bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) From d137400dcc5bab29e6739403b48ba7fb6bbd246e Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 03:09:46 +0000 Subject: [PATCH 4/8] [FUSION] rename filename --- src/pass/is_elem_wise.cc | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/pass/is_elem_wise.cc diff --git a/src/pass/is_elem_wise.cc b/src/pass/is_elem_wise.cc new file mode 100644 index 000000000000..dd8df2454ba7 --- /dev/null +++ b/src/pass/is_elem_wise.cc @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file elem_wise_detector.cc + */ +#include +#include +#include + +namespace tvm { +namespace ir { + +class ElemWiseDetector : public IRVisitor { + public: + explicit ElemWiseDetector(Array axis) : axis_(axis) {} + + void Visit(const NodeRef& e) final { + if (!is_elem_wise_) + return; + IRVisitor::Visit(e); + } + + void Visit_(const Call* op) final { + Array axis = op->args; + if (axis_.size() != axis.size()) { + is_elem_wise_ = false; + return; + } + + for (size_t i = 0; i < axis_.size(); ++i) { + const Variable *v1 = axis_[i]->var.as(); + const Variable *v2 = axis[i].as(); + if (!(v1 && v2) || (v1 != v2)) { + is_elem_wise_ = false; + return; + } + } + IRVisitor::Visit_(op); + } + + bool is_elem_wise_{true}; + + private: + Array axis_; +}; + + +bool IsElemWise(const NodeRef& node) { + if (const ComputeOpNode* compute = node.as()) { + ElemWiseDetector v = ElemWiseDetector(compute->axis); + v.Visit(compute->body); + return v.is_elem_wise_; + } + return false; +} + +} // namespace ir +} // namespace tvm From 8d6166c47afd60a93c5d610494dc73ac746e97e1 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 04:03:43 +0000 Subject: [PATCH 5/8] [FUSION] fine-tune the interface --- include/tvm/ir_pass.h | 2 +- include/tvm/schedule.h | 11 ++++ include/tvm/schedule_pass.h | 6 +- src/api/api_schedule.cc | 4 +- .../{elem_wise_detector.cc => elem_wise.cc} | 6 +- src/pass/is_elem_wise.cc | 57 ------------------- .../{fusion.cc => auto_inline_elem_wise.cc} | 12 +--- .../unittest/test_schedule_schedule_ops.py | 2 +- 8 files changed, 24 insertions(+), 76 deletions(-) rename src/pass/{elem_wise_detector.cc => elem_wise.cc} (88%) delete mode 100644 src/pass/is_elem_wise.cc rename src/schedule/{fusion.cc => auto_inline_elem_wise.cc} (63%) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 10ed591a226a..e812a6c73ecf 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -171,7 +171,7 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); * \brief Whether the node is element-wise. * \return whether the node is element-wise. */ -bool IsElemWise(const NodeRef& node); +bool IsElemWise(const Operation& node); } // namespace ir diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index f115dbc6f18f..68bcd2788d82 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -123,6 +123,12 @@ class Stage : public NodeRef { IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner, Expr x_factor, Expr y_factor); + /*! + * \brief whether the stage has been scheduled. + * \return whether the stage has been scheduled. + */ + inline bool is_scheduled(); + // declare container type using ContainerType = StageNode; }; @@ -353,6 +359,11 @@ inline StageNode* Stage::operator->() { return static_cast(node_.get()); } +inline bool Stage::is_scheduled() { + StageNode* n = operator->(); + return !(n->relations.empty() && n->attach_type == kNone); +} + inline const ScheduleNode* Schedule::operator->() const { return static_cast(node_.get()); } diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 0506fee52662..c4e82cde139b 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -34,11 +34,11 @@ Map InferBound(Schedule sch); Stmt ScheduleOps(Schedule s, Map dom_map); /*! - * \brief To automatically fuse the element-wise operations. + * \brief To automatically inline the element-wise operations. * - * \param s The schedule to be fused. + * \param sch The schedule to be inlined. */ -void AutoFuseElemWise(Schedule sch); +void AutoInlineElemWise(Schedule sch); } // namespace schedule } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 80284824d443..882ff94bde21 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -13,9 +13,9 @@ namespace tvm { namespace schedule { -TVM_REGISTER_API(_schedule_AutoFuseElemWise) +TVM_REGISTER_API(_schedule_AutoInlineElemWise) .set_body([](TVMArgs args, TVMRetValue* ret) { - AutoFuseElemWise(args[0]); + AutoInlineElemWise(args[0]); }); #define REGISTER_SCHEDULE_PASS1(PassName) \ diff --git a/src/pass/elem_wise_detector.cc b/src/pass/elem_wise.cc similarity index 88% rename from src/pass/elem_wise_detector.cc rename to src/pass/elem_wise.cc index dd8df2454ba7..7fb058681472 100644 --- a/src/pass/elem_wise_detector.cc +++ b/src/pass/elem_wise.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file elem_wise_detector.cc + * \file elem_wise.cc */ #include #include @@ -44,8 +44,8 @@ class ElemWiseDetector : public IRVisitor { }; -bool IsElemWise(const NodeRef& node) { - if (const ComputeOpNode* compute = node.as()) { +bool IsElemWise(const Operation& op) { + if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); v.Visit(compute->body); return v.is_elem_wise_; diff --git a/src/pass/is_elem_wise.cc b/src/pass/is_elem_wise.cc deleted file mode 100644 index dd8df2454ba7..000000000000 --- a/src/pass/is_elem_wise.cc +++ /dev/null @@ -1,57 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file elem_wise_detector.cc - */ -#include -#include -#include - -namespace tvm { -namespace ir { - -class ElemWiseDetector : public IRVisitor { - public: - explicit ElemWiseDetector(Array axis) : axis_(axis) {} - - void Visit(const NodeRef& e) final { - if (!is_elem_wise_) - return; - IRVisitor::Visit(e); - } - - void Visit_(const Call* op) final { - Array axis = op->args; - if (axis_.size() != axis.size()) { - is_elem_wise_ = false; - return; - } - - for (size_t i = 0; i < axis_.size(); ++i) { - const Variable *v1 = axis_[i]->var.as(); - const Variable *v2 = axis[i].as(); - if (!(v1 && v2) || (v1 != v2)) { - is_elem_wise_ = false; - return; - } - } - IRVisitor::Visit_(op); - } - - bool is_elem_wise_{true}; - - private: - Array axis_; -}; - - -bool IsElemWise(const NodeRef& node) { - if (const ComputeOpNode* compute = node.as()) { - ElemWiseDetector v = ElemWiseDetector(compute->axis); - v.Visit(compute->body); - return v.is_elem_wise_; - } - return false; -} - -} // namespace ir -} // namespace tvm diff --git a/src/schedule/fusion.cc b/src/schedule/auto_inline_elem_wise.cc similarity index 63% rename from src/schedule/fusion.cc rename to src/schedule/auto_inline_elem_wise.cc index 22a5ff69b27e..2e9668db4ee3 100644 --- a/src/schedule/fusion.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file schedule.cc + * \file auto_inline_elem_wise.cc */ #include #include @@ -8,15 +8,9 @@ namespace tvm { namespace schedule { -namespace { -inline bool is_stage_scheduled(const Stage& s) { - return !(s->relations.empty() && s->attach_type == kNone); -} -} - -void AutoFuseElemWise(Schedule sch) { +void AutoInlineElemWise(Schedule sch) { for (Stage s : sch->stages) { - if (!is_stage_scheduled(s) && ir::IsElemWise(s->op)) { + if (!s.is_scheduled() && ir::IsElemWise(s->op)) { bool is_root = false; for (auto r : sch->roots) { if (r == s->op) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 3c0ee43953b8..06c364dfe2b9 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -52,7 +52,7 @@ def test_fusion(): T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2') s = tvm.Schedule(T2.op) - tvm.schedule.AutoFuseElemWise(s) + tvm.schedule.AutoInlineElemWise(s) bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) From fbf9ff085536bc5c252f6b8d1c82e7e08fa745d1 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 04:06:02 +0000 Subject: [PATCH 6/8] [FUSION] typo --- include/tvm/ir_pass.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index e812a6c73ecf..2d12d581d382 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -168,10 +168,10 @@ Array SplitHostDevice(LoweredFunc func); LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); /*! - * \brief Whether the node is element-wise. - * \return whether the node is element-wise. + * \brief Whether the operation is element-wise. + * \return whether the operation is element-wise. */ -bool IsElemWise(const Operation& node); +bool IsElemWise(const Operation& op); } // namespace ir From d547e520a38e5bab544c62b159be55f5ba85a28c Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 05:43:54 +0000 Subject: [PATCH 7/8] move elem_wise to schedule --- include/tvm/ir_pass.h | 7 ---- include/tvm/schedule.h | 6 +-- src/pass/elem_wise.cc | 57 --------------------------- src/schedule/auto_inline_elem_wise.cc | 50 ++++++++++++++++++++++- 4 files changed, 52 insertions(+), 68 deletions(-) delete mode 100644 src/pass/elem_wise.cc diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 2d12d581d382..9e3e1b0a1d53 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -167,13 +167,6 @@ Array SplitHostDevice(LoweredFunc func); */ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); -/*! - * \brief Whether the operation is element-wise. - * \return whether the operation is element-wise. - */ -bool IsElemWise(const Operation& op); - - } // namespace ir } // namespace tvm diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 68bcd2788d82..a7cd58c96524 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -127,7 +127,7 @@ class Stage : public NodeRef { * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. */ - inline bool is_scheduled(); + inline bool is_scheduled() const; // declare container type using ContainerType = StageNode; @@ -359,8 +359,8 @@ inline StageNode* Stage::operator->() { return static_cast(node_.get()); } -inline bool Stage::is_scheduled() { - StageNode* n = operator->(); +inline bool Stage::is_scheduled() const { + const StageNode* n = operator->(); return !(n->relations.empty() && n->attach_type == kNone); } diff --git a/src/pass/elem_wise.cc b/src/pass/elem_wise.cc deleted file mode 100644 index 7fb058681472..000000000000 --- a/src/pass/elem_wise.cc +++ /dev/null @@ -1,57 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file elem_wise.cc - */ -#include -#include -#include - -namespace tvm { -namespace ir { - -class ElemWiseDetector : public IRVisitor { - public: - explicit ElemWiseDetector(Array axis) : axis_(axis) {} - - void Visit(const NodeRef& e) final { - if (!is_elem_wise_) - return; - IRVisitor::Visit(e); - } - - void Visit_(const Call* op) final { - Array axis = op->args; - if (axis_.size() != axis.size()) { - is_elem_wise_ = false; - return; - } - - for (size_t i = 0; i < axis_.size(); ++i) { - const Variable *v1 = axis_[i]->var.as(); - const Variable *v2 = axis[i].as(); - if (!(v1 && v2) || (v1 != v2)) { - is_elem_wise_ = false; - return; - } - } - IRVisitor::Visit_(op); - } - - bool is_elem_wise_{true}; - - private: - Array axis_; -}; - - -bool IsElemWise(const Operation& op) { - if (const ComputeOpNode* compute = op.as()) { - ElemWiseDetector v = ElemWiseDetector(compute->axis); - v.Visit(compute->body); - return v.is_elem_wise_; - } - return false; -} - -} // namespace ir -} // namespace tvm diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 2e9668db4ee3..66816c955acb 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -3,9 +3,57 @@ * \file auto_inline_elem_wise.cc */ #include -#include +#include namespace tvm { +namespace ir { + +class ElemWiseDetector : public IRVisitor { + public: + explicit ElemWiseDetector(Array axis) : axis_(axis) {} + + void Visit(const NodeRef& e) final { + if (!is_elem_wise_) return; + IRVisitor::Visit(e); + } + + void Visit_(const Call* op) final { + Array axis = op->args; + if (axis_.size() != axis.size()) { + is_elem_wise_ = false; + return; + } + + for (size_t i = 0; i < axis_.size(); ++i) { + // const Variable *v1 = axis_[i]->var.as(); + // const Variable *v2 = axis[i].as(); + if (!axis[i].same_as(axis_[i]->var)) { + // if (!(v1 && v2) || (v1 != v2)) { + is_elem_wise_ = false; + return; + } + } + IRVisitor::Visit_(op); + } + + bool is_elem_wise_{true}; + + private: + Array axis_; +}; + + +bool IsElemWise(const Operation& op) { + if (const ComputeOpNode* compute = op.as()) { + ElemWiseDetector v = ElemWiseDetector(compute->axis); + v.Visit(compute->body); + return v.is_elem_wise_; + } + return false; +} + +} // namespace ir + namespace schedule { void AutoInlineElemWise(Schedule sch) { From ae4de6f0115f76af2b58882e485106d3514305ad Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 9 Feb 2017 05:55:43 +0000 Subject: [PATCH 8/8] rename test function --- tests/python/unittest/test_schedule_schedule_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 06c364dfe2b9..9689a1c34fc4 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -42,7 +42,7 @@ def test_schedule2(): stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) -def test_fusion(): +def test_auto_inline(): m = tvm.Var('m') n = tvm.Var('n') A = tvm.placeholder((m, n), name='A') @@ -62,4 +62,4 @@ def test_fusion(): test_schedule0() test_schedule1() test_schedule2() - test_fusion() + test_auto_inline()