From 1c1ba2ad660182597f62f2326a68a8e0be95e191 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 15 Jan 2017 17:14:32 -0800 Subject: [PATCH 1/2] [PASS] Export simplify and equal to python --- include/tvm/ir_pass.h | 4 ++++ src/c_api/c_api_pass.cc | 21 +++++++++++++++++++++ tests/python/test_pass_basic.py | 10 ++++++++++ 3 files changed, 35 insertions(+) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index d4456ed74cd4..6fb7236157b9 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -9,6 +9,8 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ +#include +#include #include #include #include @@ -19,6 +21,8 @@ namespace tvm { namespace ir { +using Halide::Internal::equal; +using Halide::Internal::simplify; /*! * \brief Schedule s' dependent operations. diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index e05f696bd35b..c667069ce189 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -13,6 +13,27 @@ namespace ir { using ArgStack = const std::vector; using RetValue = APIVariantValue; +TVM_REGISTER_API(_pass_Simplify) +.set_body([](const ArgStack& args, RetValue *ret) { + CHECK(args.at(0).type_id == kNodeHandle); + if (dynamic_cast(args.at(0).sptr.get())) { + *ret = simplify(args.at(0).operator Expr()); + } else { + *ret = simplify(args.at(0).operator Stmt()); + } + }); + +TVM_REGISTER_API(_pass_equal) +.set_body([](const ArgStack& args, RetValue *ret) { + CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(1).type_id == kNodeHandle); + if (dynamic_cast(args.at(0).sptr.get())) { + *ret = equal(args.at(0).operator Expr(), args.at(1).operator Expr()); + } else { + *ret = equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + } + }); + // make from two arguments #define REGISTER_PASS1(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index 23262f1cc9f1..25ff5ea717c2 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -1,5 +1,15 @@ import tvm +def test_simplify(): + x = tvm.Var('x') + e1 = tvm.ir_pass.Simplify(x + 2 + 1) + assert(tvm.ir_pass.equal(e1, x + 3)) + e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x) + assert(tvm.ir_pass.equal(e2, x * 8)) + e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) + assert(tvm.ir_pass.equal(e3, tvm.make.Mod(x, 3))) + + def test_verify_ssa(): x = tvm.Var('x') y = tvm.Var() From 7988c287b048f239a02e6c073e3da70b95af82a6 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 15 Jan 2017 21:45:12 -0800 Subject: [PATCH 2/2] fix naming convention --- include/tvm/ir_pass.h | 17 +++++++++++++++-- src/c_api/c_api_pass.cc | 10 +++++----- tests/python/test_pass_basic.py | 6 +++--- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 6fb7236157b9..a45bbbb91fd8 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -21,8 +21,21 @@ namespace tvm { namespace ir { -using Halide::Internal::equal; -using Halide::Internal::simplify; +inline bool Equal(Expr a, Expr b) { + return Halide::Internal::equal(a, b); +} + +inline bool Equal(Stmt a, Stmt b) { + return Halide::Internal::equal(a, b); +} + +inline Expr Simplify(Expr a) { + return Halide::Internal::simplify(a); +} + +inline Stmt Simplify(Stmt a) { + return Halide::Internal::simplify(a); +} /*! * \brief Schedule s' dependent operations. diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index c667069ce189..10ffe95f653d 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -17,20 +17,20 @@ TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { CHECK(args.at(0).type_id == kNodeHandle); if (dynamic_cast(args.at(0).sptr.get())) { - *ret = simplify(args.at(0).operator Expr()); + *ret = Simplify(args.at(0).operator Expr()); } else { - *ret = simplify(args.at(0).operator Stmt()); + *ret = Simplify(args.at(0).operator Stmt()); } }); -TVM_REGISTER_API(_pass_equal) +TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(1).type_id == kNodeHandle); if (dynamic_cast(args.at(0).sptr.get())) { - *ret = equal(args.at(0).operator Expr(), args.at(1).operator Expr()); + *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); } else { - *ret = equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); } }); diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index 25ff5ea717c2..ebffc58805f3 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -3,11 +3,11 @@ def test_simplify(): x = tvm.Var('x') e1 = tvm.ir_pass.Simplify(x + 2 + 1) - assert(tvm.ir_pass.equal(e1, x + 3)) + assert(tvm.ir_pass.Equal(e1, x + 3)) e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.ir_pass.equal(e2, x * 8)) + assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) - assert(tvm.ir_pass.equal(e3, tvm.make.Mod(x, 3))) + assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) def test_verify_ssa():