From d478873511602d4ff34f2fd710cea8b7c877ff23 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Sep 2019 09:47:06 -0700 Subject: [PATCH 1/6] [ARITH] Switch indexdiv/mod to floor --- python/tvm/expr.py | 9 +++------ src/lang/buffer.cc | 4 ++-- src/lang/expr_operator.cc | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 5b7c60d819bd..733f57a68c56 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -92,16 +92,13 @@ def __rtruediv__(self, other): return _generic.divide(other, self) def __floordiv__(self, other): - # return _generic.floordiv(self, other) - return _generic.divide(self, other) + return _generic.floordiv(self, other) def __rfloordiv__(self, other): - # return _generic.floordiv(other, self) - return _generic.divide(other, self) + return _generic.floordiv(other, self) def __mod__(self, other): - raise div_ambiguity_error() - # return _make._OpMod(self, other) + return _make._OpFloorMod(self, other) def __neg__(self): neg_one = _api_internal._const(-1, self.dtype) diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 206056bf889b..689b291ae2ed 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -32,8 +32,8 @@ namespace tvm { // TODO(tqchen): change to floormod/div -using IndexMod = ir::Mod; -using IndexDiv = ir::Div; +using IndexMod = ir::FloorMod; +using IndexDiv = ir::FloorDiv; Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 46a0737eab7e..9c9100b1902e 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -208,11 +208,11 @@ Expr operator%(Expr a, Expr b) { // TODO(tqchen): switch to floordiv Expr indexdiv(Expr a, Expr b) { - return truncdiv(a, b); + return floordiv(a, b); } Expr indexmod(Expr a, Expr b) { - return truncmod(a, b); + return floormod(a, b); } Expr floordiv(Expr a, Expr b) { From 3062f5299de35c0169f3ac0ed31ff0cbf289a0fc Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 28 Sep 2019 20:16:38 -0700 Subject: [PATCH 2/6] additional attrs support --- src/lang/attr_functor.h | 13 +++++++++++-- src/lang/attrs.cc | 8 ++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 249ce523a3cc..995dfb392e87 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -87,6 +87,8 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -119,6 +121,9 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Sub); ATTR_FUNCTOR_DISPATCH(Mul); ATTR_FUNCTOR_DISPATCH(Div); + ATTR_FUNCTOR_DISPATCH(Mod); + ATTR_FUNCTOR_DISPATCH(FloorDiv); + ATTR_FUNCTOR_DISPATCH(FloorMod); ATTR_FUNCTOR_DISPATCH(Min); ATTR_FUNCTOR_DISPATCH(Max); ATTR_FUNCTOR_DISPATCH(GE); @@ -160,6 +165,8 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::FloorDiv* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::FloorMod* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::GE* lhs, const NodeRef& other) final; @@ -201,6 +208,8 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::Mul* op) final; size_t VisitAttr_(const ir::Div* op) final; size_t VisitAttr_(const ir::Mod* op) final; + size_t VisitAttr_(const ir::FloorDiv* op) final; + size_t VisitAttr_(const ir::FloorMod* op) final; size_t VisitAttr_(const ir::Min* op) final; size_t VisitAttr_(const ir::Max* op) final; size_t VisitAttr_(const ir::GE* op) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index ec2fd742ba14..c5b14ac577ec 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -154,6 +154,8 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod); TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); TVM_DEFINE_ATTRS_BINOP_EQUAL(GE); @@ -266,6 +268,8 @@ TVM_DEFINE_ATTRS_BINOP_HASH(Sub); TVM_DEFINE_ATTRS_BINOP_HASH(Mul); TVM_DEFINE_ATTRS_BINOP_HASH(Div); TVM_DEFINE_ATTRS_BINOP_HASH(Mod); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); TVM_DEFINE_ATTRS_BINOP_HASH(Max); TVM_DEFINE_ATTRS_BINOP_HASH(Min); TVM_DEFINE_ATTRS_BINOP_HASH(GE); From 1ecb4c79403fa142480ca856ad56bb2fc6f3185b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 29 Sep 2019 09:27:02 -0700 Subject: [PATCH 3/6] Fix vm tests --- tests/python/unittest/test_codegen_device.py | 2 ++ tests/python/unittest/test_codegen_vm_basic.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 6cb424c8a5eb..983cf0a231f9 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -48,6 +48,8 @@ def test_add_pipeline(): stmt = tvm.ir_pass.Simplify(stmt) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] + # lower the floordiv + fsplits = [tvm.ir_pass.LowerIntrin(x, "generic") for x in fsplits] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) def check_target(device, host="stackvm"): diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index a9b382f1fd61..42ec8622f40d 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -37,6 +37,7 @@ def tvm_call_back_get_shape(shape0): stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) + fapi = tvm.ir_pass.LowerIntrin(fapi, "generic") run_jit(fapi, lambda f: f(a)) From 0cefa6d7224ba05230721124c283499bc6ab8c5b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 29 Sep 2019 16:06:10 -0700 Subject: [PATCH 4/6] stackvm patch --- src/pass/lower_intrin.cc | 10 ++++++++-- tests/python/unittest/test_codegen_device.py | 4 ++-- tests/python/unittest/test_codegen_vm_basic.py | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index bbc3c572ca7e..96c277495767 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -46,6 +46,9 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { patterns_.push_back("tvm.intrin.rule." + starget + "."); patterns_.push_back("tvm.intrin.rule.default."); fma_ = runtime::Registry::Get(patterns_[0] + "fma"); + if (target == "stackvm") { + support_bitwise_op_ = false; + } } Expr Mutate_(const Call* op, const Expr& e) final { @@ -76,7 +79,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { const DataType& dtype = op->type; CHECK(dtype.is_int() || !dtype.is_uint()); - if (is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && + is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. return op->a >> make_const(dtype, shift); } @@ -122,7 +126,8 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { const DataType& dtype = op->type; CHECK(dtype.is_int() || !dtype.is_uint()); - if (is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && + is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. int64_t mask = ( static_cast(1) << static_cast(shift)) - 1; @@ -268,6 +273,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // patterns std::vector patterns_; const PackedFunc* fma_{nullptr}; + bool support_bitwise_op_{true}; }; Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { diff --git a/tests/python/unittest/test_codegen_device.py b/tests/python/unittest/test_codegen_device.py index 983cf0a231f9..45ecf9539337 100644 --- a/tests/python/unittest/test_codegen_device.py +++ b/tests/python/unittest/test_codegen_device.py @@ -48,8 +48,8 @@ def test_add_pipeline(): stmt = tvm.ir_pass.Simplify(stmt) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)] - # lower the floordiv - fsplits = [tvm.ir_pass.LowerIntrin(x, "generic") for x in fsplits] + # lower the floordiv(use stackvm rules so it works for all targets) + fsplits = [tvm.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits] fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0]) def check_target(device, host="stackvm"): diff --git a/tests/python/unittest/test_codegen_vm_basic.py b/tests/python/unittest/test_codegen_vm_basic.py index 42ec8622f40d..7ff217728034 100644 --- a/tests/python/unittest/test_codegen_vm_basic.py +++ b/tests/python/unittest/test_codegen_vm_basic.py @@ -37,7 +37,7 @@ def tvm_call_back_get_shape(shape0): stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.ir_pass.LowerTVMBuiltin(fapi) - fapi = tvm.ir_pass.LowerIntrin(fapi, "generic") + fapi = tvm.ir_pass.LowerIntrin(fapi, "stackvm") run_jit(fapi, lambda f: f(a)) From 82508463a754abcbfcc684962f67dcd660d70049 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 29 Sep 2019 16:51:03 -0700 Subject: [PATCH 5/6] stackvm --- src/pass/lower_intrin.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 96c277495767..3935d23cce0c 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -97,7 +97,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // condition on b >= 0. // truncmod(a, b) < 0 will implies ceildiv, // So we need to correct these cases. - if (dtype == Int(32) || dtype == Int(64)) { + if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { @@ -145,7 +145,7 @@ class IntrinInjecter : public arith::IRMutatorWithAnalyzer { // mod(a, b) < 0 will imply we are doing ceildiv, // So we need to correct these cases. Expr rmod = truncmod(op->a, op->b); - if (dtype == Int(32) || dtype == Int(64)) { + if ((dtype == Int(32) || dtype == Int(64)) && support_bitwise_op_) { // (rmod >> shift) & b // -> (rmod >= 0 ? 0: -1) & b // -> rmod >= 0 ? 0 : b From 94ef2095e37579efe2435cc2d2e281663bdcc25e Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 29 Sep 2019 19:42:04 -0700 Subject: [PATCH 6/6] fix nms --- topi/python/topi/cuda/nms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 33fc7249802b..d032527ec273 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -185,7 +185,7 @@ def get_valid_counts_scan(data, partial_in, partial): ib.scope_attr(bx, "thread_extent", nthread_bx) var = tvm.make.node("FloatImm", dtype="float32", value=2) new_range = num_anchors // elem_per_thread + 1 - iteration = log(cast(new_range, "float32")) // math.log(2) + iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32") # Scan: Kogge-Stone adder with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))): with ib.for_range(0, iteration) as k: