From 9c47ede077b84fdf917161fabf0f03a11ec6c889 Mon Sep 17 00:00:00 2001 From: Yutetsu TAKATSUKASA Date: Wed, 20 Mar 2019 15:30:56 +0900 Subject: [PATCH 1/4] Add a unittest to show the current behavior of DetectLinearEquation() with empty vars. --- .../unittest/test_arith_detect_linear_equation.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 2b0f327b65b2..9b1f92903881 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -20,6 +20,11 @@ def test_basic(): m = tvm.arith.DetectLinearEquation(b * 7, [a]) assert m[0].value == 0 + m = tvm.arith.DetectLinearEquation(b * 7, []) + assert len(m) == 2 + assert m[0].value == 1 + assert tvm.ir_pass.Simplify(m[1] - b * 7).value == 0 + def test_multivariate(): v = [tvm.var("v%d" % i) for i in range(4)] b = tvm.var("b") @@ -42,6 +47,11 @@ def test_multivariate(): assert(m[0].value == 0) assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + m = tvm.arith.DetectLinearEquation((v[0] - v[1]), []) + assert(len(m) == 2) + assert(m[0].value == 1) + assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + if __name__ == "__main__": test_basic() test_multivariate() From 1ae02d360ba93bed0658a23ce25de971fe4e2bff Mon Sep 17 00:00:00 2001 From: Yutetsu TAKATSUKASA Date: Wed, 20 Mar 2019 15:32:19 +0900 Subject: [PATCH 2/4] Now DetectLinearEquation() returns the input expr as-is when empty vars is passed. --- src/arithmetic/detect_linear_equation.cc | 30 ++++++++----------- .../test_arith_detect_linear_equation.py | 10 +++---- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 6f4d3cfb53bb..e7bc7e74b675 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -127,25 +127,21 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { Expr base = e; Array coeff; - if (0 == vars.size()) { - coeff.push_back(make_const(Int(32), 1)); - } else { - for (Var v : vars) { - LinearEqEntry ret; - if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array(); - } - coeff.push_back(ret.coeff); - base = std::move(ret.base); + for (Var v : vars) { + LinearEqEntry ret; + if (!LinearEqDetector(v).Detect(base, &ret)) { + return Array(); } + coeff.push_back(ret.coeff); + base = std::move(ret.base); + } - std::unordered_set vset; - for (size_t i = vars.size(); i != 1; --i) { - vset.insert(vars[i - 1].get()); - // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { - return Array(); - } + std::unordered_set vset; + for (size_t i = vars.size(); i > 1; --i) { + vset.insert(vars[i - 1].get()); + // The previous coeff contains the variable + if (ExprUseVar(coeff[i - 2], vset)) { + return Array(); } } coeff.push_back(base); diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 9b1f92903881..33e266684f09 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -21,9 +21,8 @@ def test_basic(): assert m[0].value == 0 m = tvm.arith.DetectLinearEquation(b * 7, []) - assert len(m) == 2 - assert m[0].value == 1 - assert tvm.ir_pass.Simplify(m[1] - b * 7).value == 0 + assert len(m) == 1 + assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0 def test_multivariate(): v = [tvm.var("v%d" % i) for i in range(4)] @@ -48,9 +47,8 @@ def test_multivariate(): assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) m = tvm.arith.DetectLinearEquation((v[0] - v[1]), []) - assert(len(m) == 2) - assert(m[0].value == 1) - assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + assert(len(m) == 1) + assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) if __name__ == "__main__": test_basic() From cacbbbaf50da3e3db2ab766933dcf76f3eb06079 Mon Sep 17 00:00:00 2001 From: Yutetsu TAKATSUKASA Date: Wed, 20 Mar 2019 15:33:58 +0900 Subject: [PATCH 3/4] update inject_copy_intrin.cc to catch up with DetectLinearEquation(). The result of InjectCopyIntrin() is as same as before. --- src/pass/inject_copy_intrin.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 7ca1d133bd2d..8fac88a44d85 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator { bool MatchCopyPattern(Stmt stmt, Stmt *out) { using namespace arith; Stmt body = stmt; - bool is_single_point_copy = false; // strip the loops std::vector loops; @@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator { const Cast* cast = store->value.as(); const Load* load = store->value.as(); if (0 == loops.size()) { - is_single_point_copy = true; CHECK(!has_cond); } // for now only support true condition matching @@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator { arith::DetectLinearEquation(load->index, loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; - auto loop_var_size = loop_vars.size(); - if (is_single_point_copy) { - loop_var_size = 1; + const size_t loop_var_size = loop_vars.size(); + if (loop_var_size == 0) { dst_shape.push_back(make_const(Int(32), 1)); } else { for (const For* op : loops) { @@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator { CHECK_EQ(load_strides.size(), loop_var_size + 1); Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); + if(loop_var_size == 0) { + src_strides.push_back(make_const(Int(32), 1)); + dst_strides.push_back(make_const(Int(32), 1)); + } Buffer dst = BufferNode::make( Var(store->buffer_var.node_), store->value.type(), From 6ee59d2b572d091a3d8b99d7c1d06047517547d2 Mon Sep 17 00:00:00 2001 From: Yutetsu TAKATSUKASA Date: Wed, 20 Mar 2019 16:11:41 +0900 Subject: [PATCH 4/4] comply the coding standard --- src/pass/inject_copy_intrin.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 8fac88a44d85..7dcfcfdae239 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -129,7 +129,7 @@ class CopyIntrinInjector : public IRMutator { CHECK_EQ(load_strides.size(), loop_var_size + 1); Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); - if(loop_var_size == 0) { + if (loop_var_size == 0) { src_strides.push_back(make_const(Int(32), 1)); dst_strides.push_back(make_const(Int(32), 1)); }