From 48b6d2db55b5f92d1c4ec490bc9cff6b57df7a88 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Sat, 30 Mar 2019 23:30:13 -0700
Subject: [PATCH 1/7] [relay][bugfix] fuse injective to elemwise and broadcast

---
 src/relay/pass/fuse_ops.cc                 | 1 +
 tests/python/relay/test_pass_annotation.py | 3 ++-
 tests/python/relay/test_pass_fuse_ops.py   | 6 ++++--
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 55d609872929..f49cb24f9754 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -719,6 +719,7 @@ class GraphPartitioner {
       } else {
         return (kind <= kBroadcast ||
                 kind == kCommReduce ||
+                kind == kInjective ||
                 kind == kOutEWiseFusable);
       }
     };
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index 04081e06735b..fe71b21509e4 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -278,7 +278,6 @@ def test_runtime(target, device, func, fallback_device=None,
         graph_json = json.loads(graph)
         if "device_index" in graph_json["attrs"]:
             device_index = graph_json["attrs"]["device_index"][1]
-            assert device_index == expected_index
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
         mod.run()
@@ -291,6 +290,7 @@ def test_fuse_log_add(device, tgt):
     target = {"cpu": "llvm", device: tgt}
     cpu_ctx = fallback_device
     dev_ctx = tvm.context(device)
+    dev_ty = dev_ctx.device_type

     def annotated():
         add = relay.add(x, y)
@@ -372,6 +372,7 @@ def test_fallback_exp(device, tgt):
     target = {"cpu": "llvm", device: tgt}
     cpu_ctx = fallback_device
     dev_ctx = tvm.context(device)
+    dev_ty = dev_ctx.device_type

     def annotated():
         add = relay.add(x, y)
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index bdffdf7c129f..5057ea43b729 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -23,13 +23,15 @@ def before():
         x = relay.var("x", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        return relay.Function([x], z)
+        w = relay.squeeze(z)
+        return relay.Function([x], w)

     def expected():
         x = relay.var("p", shape=(10, 20))
         y = relay.add(x, relay.const(1, "float32"))
         z = relay.exp(y)
-        f1 = relay.Function([x], z)
+        w = relay.squeeze(z)
+        f1 = relay.Function([x], w)
         x = relay.var("x", shape=(10, 20))
         y = relay.Call(f1, [x])
         return relay.Function([x], y)
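
Review note (PATCH 1/7): in GraphPartitioner, this fcond is the predicate checked along the path from an elemwise/broadcast group to its post-dominator, and is_sink is true for the post-dominator itself. The sink branch accepted kElemWise, kBroadcast, kCommReduce, and kOutEWiseFusable but skipped kInjective, so a chain ending in an injective op such as squeeze was split into two groups. The updated test above encodes exactly the fixed pattern; a standalone sketch of it (a minimal repro, assuming the same relay.ir_pass API the tests use):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.add(x, relay.const(1, "float32"))  # elemwise
    z = relay.exp(y)                             # elemwise
    w = relay.squeeze(z)                         # injective sink
    f = relay.ir_pass.infer_type(relay.Function([x], w))
    fused = relay.ir_pass.fuse_ops(f, opt_level=2)
    # With this patch, printing the result should show a single fused
    # function wrapping add, exp, and squeeze.
    print(fused)
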
From bd4a8de2507495275b7f2b8e67629c3571911625 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Sun, 31 Mar 2019 10:56:16 -0700
Subject: [PATCH 2/7] enhance fusion for parallel injective ops

---
 src/relay/pass/fuse_ops.cc               |  4 +++-
 tests/python/relay/test_pass_fuse_ops.py | 33 ++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index f49cb24f9754..fc7aad6ce515 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -715,7 +715,9 @@ class GraphPartitioner {
     // The final terminal node can already be fused to a OutEWiseFusable group.
     auto fcond = [](OpPatternKind kind, bool is_sink) {
       if (!is_sink) {
-        return kind <= kBroadcast;
+        // Elemwise, broadcast, and injective ops on the parallel branches
+        // are allowed to be fused to the elemwise/broadcast master.
+        return kind <= kInjective;
       } else {
         return (kind <= kBroadcast ||
                 kind == kCommReduce ||
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 5057ea43b729..6d6781046a10 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -505,6 +505,38 @@ def expected(dshape):
     assert relay.ir_pass.alpha_equal(zz, after)


+def test_fuse_parallel_injective():
+    """Test fusing parallel injective ops to an elemwise op."""
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        return relay.Function([x], w)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.squeeze(y)
+        u = relay.transpose(y, axes=[0, 1])
+        w = relay.left_shift(z, u)
+        f1 = relay.Function([x], w)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        return relay.Function([x], y)
+
+    z = before()
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=0)
+    assert not relay.ir_pass.free_vars(zz)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    assert not relay.ir_pass.free_vars(zz)
+    after = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -517,3 +549,4 @@ def expected(dshape):
     test_tuple_intermediate()
     test_tuple_consecutive()
     test_inception_like()
+    test_fuse_parallel_injective()
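
Review note (PATCH 2/7): OpPatternKind is ordered by how hard an op is to fuse: kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3, kOutEWiseFusable = 4. On non-sink nodes, kind <= kInjective therefore admits exactly the elemwise, broadcast, and injective ops named in the new comment, while reductions and opaque ops still block the path. Combined with PATCH 1/7, a diamond of parallel injective branches hanging off an elemwise master now collapses into a single group; test_fuse_parallel_injective above builds exactly that diamond (squeeze and transpose both reading the add) and checks alpha-equality against the hand-fused form.
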
From 0dd1162993dd7c8bae3e2f1aff5cb01cf837d88e Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Thu, 4 Apr 2019 18:17:02 -0700
Subject: [PATCH 3/7] check if tensor in schedule

---
 include/tvm/schedule.h              | 16 ++++++++++++++++
 src/relay/backend/compile_engine.cc |  4 +++-
 src/schedule/schedule_lang.cc       |  4 ++++
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 9a556b6ce960..61447573f76f 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -551,6 +551,22 @@ class ScheduleNode : public Node {
   /*! \brief Invalidate temp cache. */
   void InvalidateCache();

+  /*!
+   * \brief Check if the schedule contains an Operation.
+   * \param op The candidate Operation.
+   * \return true if the schedule has the Operation. Otherwise, false.
+   */
+  EXPORT bool Contain(const Operation& op) const;
+
+  /*!
+   * \brief Check if the schedule contains a Tensor.
+   * \param op The candidate tensor.
+   * \return true if the schedule has the tensor. Otherwise, false.
+   */
+  EXPORT bool Contain(const Tensor& tensor) const {
+    return Contain(tensor->op);
+  }
+
   /*!
    * \brief Create a schedule for array of ops(and their dependencies).
    * \param ops The ops to be scheduled.
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 43515105bd94..4b5842c36020 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -127,7 +127,9 @@ class ScheduleGetter :
       schedule = fschedule[master_op_](master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
-        schedule[scalar].compute_inline();
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
       }
     }
     return std::make_pair(schedule, cfunc);
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index ffee804198b6..e1cb4c5f9bdc 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -712,6 +712,10 @@ void ScheduleNode::InitCache() {
   CHECK_EQ(op2stage_cache_.size(), stages.size());
 }

+bool ScheduleNode::Contain(const Operation& op) const {
+  return stage_map.find(op) != stage_map.end();
+}
+
 Schedule ScheduleNode::make(Array<Operation> ops) {
   auto n = make_node<ScheduleNode>();
   Schedule sch(n);
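
Review note (PATCH 3/7): ScheduleGetter collects the constant scalars it creates into scalars_ and inlines each into the master op's schedule, but with the relaxed fusion above, the schedule returned by fschedule[master_op_] is no longer guaranteed to contain a stage for every such scalar. Schedule::operator[] CHECK-fails for an operation that has no stage, which is exactly what the new Contain guard avoids. The failure mode is easy to reproduce at the tensor-expression level (a minimal sketch with hypothetical tensors, using the top-level tvm.placeholder/tvm.compute API of this era):

    import tvm

    A = tvm.placeholder((10,), name="A")
    B = tvm.compute((10,), lambda i: A[i] + 1, name="B")
    C = tvm.compute((10,), lambda i: B[i] * 2, name="C")

    s = tvm.create_schedule(C.op)  # stages exist for A, B, and C
    s[B].compute_inline()          # fine: B has a stage in this schedule

    D = tvm.compute((10,), lambda i: A[i] - 1, name="D")
    # s[D].compute_inline()        # fails: D.op has no stage in this schedule
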
From e5f6d47b712db081e177faf1749fd6d33b10ca7c Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Fri, 5 Apr 2019 00:07:03 -0700
Subject: [PATCH 4/7] fix codegen

---
 src/codegen/llvm/codegen_llvm.cc | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 7946f906125f..291cac98ab34 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -868,7 +868,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-  if (op->a.type().is_int() || op->a.type().is_uint()) {
+
+  if (op->a.type().is_handle() && op->b.type().is_handle()) {
+    return builder_->CreateICmpEQ(a, b);
+  } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
+    LOG(FATAL) << "Both or none of the operands should be pointers."
+               << "\n";
+    return nullptr;
+  } else if (op->a.type().is_int() || op->a.type().is_uint()) {
     return builder_->CreateICmpEQ(a, b);
   } else {
     return builder_->CreateFCmpOEQ(a, b);
@@ -878,7 +885,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-  if (op->a.type().is_int() || op->a.type().is_uint()) {
+
+  if (op->a.type().is_handle() && op->b.type().is_handle()) {
+    return builder_->CreateICmpNE(a, b);
+  } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
+    LOG(FATAL) << "Both or none of the operands should be pointers."
+               << "\n";
+    return nullptr;
+  } else if (op->a.type().is_int() || op->a.type().is_uint()) {
     return builder_->CreateICmpNE(a, b);
   } else {
     return builder_->CreateFCmpONE(a, b);

From ac5cbb377ebece05345cdb655e0b115483af5de0 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Fri, 5 Apr 2019 00:11:48 -0700
Subject: [PATCH 5/7] fix lint

---
 src/codegen/llvm/codegen_llvm.cc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 291cac98ab34..46fb0188f7c6 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -870,9 +870,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
   llvm::Value* b = MakeValue(op->b);

   if (op->a.type().is_handle() && op->b.type().is_handle()) {
-    return builder_->CreateICmpEQ(a, b);
+    return builder_->CreateICmpEQ(a, b);
   } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
-    LOG(FATAL) << "Both or none of the operands should be pointers."
+    LOG(FATAL) << "Both or none of the operands should be pointers."
                << "\n";
     return nullptr;
   } else if (op->a.type().is_int() || op->a.type().is_uint()) {
@@ -887,9 +887,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
   llvm::Value* b = MakeValue(op->b);

   if (op->a.type().is_handle() && op->b.type().is_handle()) {
-    return builder_->CreateICmpNE(a, b);
+    return builder_->CreateICmpNE(a, b);
   } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
-    LOG(FATAL) << "Both or none of the operands should be pointers."
+    LOG(FATAL) << "Both or none of the operands should be pointers."
                << "\n";
     return nullptr;
   } else if (op->a.type().is_int() || op->a.type().is_uint()) {
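
Review note (PATCH 4/7, PATCH 5/7): TVM's handle dtype lowers to an LLVM pointer. LLVM's icmp instruction is defined for integer and pointer operands alike, while fcmp is defined only for floating point, so before this change an EQ/NE on two handles fell through to the fcmp branch and produced invalid IR. The mixed pointer/non-pointer case is rejected because icmp requires both operands to have the same type; bridging them would need an explicit ptrtoint or inttoptr cast. PATCH 5/7 only strips the trailing whitespace flagged by the linter, which is why its removed and added lines look identical here. A sketch of an expression that reaches the new pointer path (hypothetical; assumes the tvm.make node constructors of this era):

    import tvm

    p = tvm.var("buf_a", dtype="handle")
    q = tvm.var("buf_b", dtype="handle")
    # Pointer equality: the LLVM backend now emits icmp eq for this
    # node rather than an invalid fcmp oeq on pointer values.
    cond = tvm.make.EQ(p, q)
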
From 05878a104f82b0aef6b1977c8535b418b9cee7f4 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Wed, 1 May 2019 01:05:41 +0000
Subject: [PATCH 6/7] update

---
 src/codegen/llvm/codegen_llvm.cc           | 18 ++----------------
 tests/python/relay/test_pass_annotation.py |  3 +--
 2 files changed, 3 insertions(+), 18 deletions(-)

diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
index 46fb0188f7c6..7946f906125f 100644
--- a/src/codegen/llvm/codegen_llvm.cc
+++ b/src/codegen/llvm/codegen_llvm.cc
@@ -868,14 +868,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-
-  if (op->a.type().is_handle() && op->b.type().is_handle()) {
-    return builder_->CreateICmpEQ(a, b);
-  } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
-    LOG(FATAL) << "Both or none of the operands should be pointers."
-               << "\n";
-    return nullptr;
-  } else if (op->a.type().is_int() || op->a.type().is_uint()) {
+  if (op->a.type().is_int() || op->a.type().is_uint()) {
     return builder_->CreateICmpEQ(a, b);
   } else {
     return builder_->CreateFCmpOEQ(a, b);
@@ -885,14 +878,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQ* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const NE* op) {
   llvm::Value* a = MakeValue(op->a);
   llvm::Value* b = MakeValue(op->b);
-
-  if (op->a.type().is_handle() && op->b.type().is_handle()) {
-    return builder_->CreateICmpNE(a, b);
-  } else if (op->a.type().is_handle() || op->b.type().is_handle()) {
-    LOG(FATAL) << "Both or none of the operands should be pointers."
-               << "\n";
-    return nullptr;
-  } else if (op->a.type().is_int() || op->a.type().is_uint()) {
+  if (op->a.type().is_int() || op->a.type().is_uint()) {
     return builder_->CreateICmpNE(a, b);
   } else {
     return builder_->CreateFCmpONE(a, b);
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index fe71b21509e4..04081e06735b 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -278,6 +278,7 @@ def test_runtime(target, device, func, fallback_device=None,
         graph_json = json.loads(graph)
         if "device_index" in graph_json["attrs"]:
             device_index = graph_json["attrs"]["device_index"][1]
+            assert device_index == expected_index
         mod = graph_runtime.create(graph, lib, contexts)
         mod.set_input(**params)
         mod.run()
@@ -290,7 +291,6 @@ def test_fuse_log_add(device, tgt):
     target = {"cpu": "llvm", device: tgt}
     cpu_ctx = fallback_device
     dev_ctx = tvm.context(device)
-    dev_ty = dev_ctx.device_type

     def annotated():
         add = relay.add(x, y)
@@ -372,7 +372,6 @@ def test_fallback_exp(device, tgt):
     target = {"cpu": "llvm", device: tgt}
     cpu_ctx = fallback_device
     dev_ctx = tvm.context(device)
-    dev_ty = dev_ctx.device_type

     def annotated():
         add = relay.add(x, y)

From 0cff7f228bc8a000f1f5efb3408556f6b5989780 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Wed, 1 May 2019 01:46:45 +0000
Subject: [PATCH 7/7] lint

---
 include/tvm/schedule.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 61447573f76f..6c2a759db471 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -560,7 +560,7 @@ class ScheduleNode : public Node {

   /*!
    * \brief Check if the schedule contains a Tensor.
-   * \param op The candidate tensor.
+   * \param tensor The candidate tensor.
    * \return true if the schedule has the tensor. Otherwise, false.
    */
   EXPORT bool Contain(const Tensor& tensor) const {
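
Review note (series): PATCH 6/7 reverts both the LLVM codegen change and the temporary edits to test_pass_annotation.py, and PATCH 7/7 fixes the \param name flagged by the doc linter, so the net effect of the series is PATCH 1/7 through 3/7: injective ops may terminate an elemwise/broadcast fusion group and sit on its parallel branches, and the compile engine only inlines scalars whose stages the master schedule actually contains. An end-to-end smoke check of that final state (hypothetical driver, same era-API assumptions as the notes above; building should exercise the guarded scalar inlining for the relay.const operand):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(10, 20))
    y = relay.add(x, relay.const(1, "float32"))
    w = relay.squeeze(relay.exp(y))
    func = relay.ir_pass.infer_type(relay.Function([x], w))
    graph, lib, params = relay.build(func, target="llvm")
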