From 7fdad4f7e3d3e7ef0d18bea5d2aea7339e2f97b0 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Sat, 18 Apr 2020 01:45:43 +0000
Subject: [PATCH] fix fuse over functions that are handled by external codegen

---
 src/relay/backend/vm/compiler.cc                | 14 +++++++-------
 src/relay/transforms/fuse_ops.cc                |  3 +++
 tests/python/relay/test_pass_partition_graph.py |  1 -
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7e2d43e7b35d..cb61a61470c7 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
-  // Manifest the allocations.
-  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
-  // Compute away possibly introduced constant computation.
-  pass_seqs.push_back(transform::FoldConstant());
-  // Fuse the shape functions.
-  pass_seqs.push_back(transform::FuseOps());
-
   // Inline the functions that are lifted to the module scope. We perform this
   // pass after all other optimization passes but before the memory allocation
   // pass. This is because memory allocation pass will insert `invoke_tvm_op`
@@ -938,6 +931,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   // external codegen.
   pass_seqs.push_back(transform::Inline());
 
+  // Manifest the allocations.
+  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
+  // Compute away possibly introduced constant computation.
+  pass_seqs.push_back(transform::FoldConstant());
+  // Fuse the shape functions.
+  pass_seqs.push_back(transform::FuseOps());
+
   // Manifest the allocations needed for the shape functions.
   pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
 
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index f646042962f0..e37b44c12dca 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -199,6 +199,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
 
   // Post order tree
   void VisitExpr_(const FunctionNode* op) final {
+    // Skip the function that should be handled by external codegen.
+    if (op->GetAttr(attr::kCompiler).defined()) return;
+
     for (auto param : op->params) {
       this->Update(param, nullptr, kOpaque);
     }
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 3261ccd0d7c9..14d57a92f106 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -457,7 +457,6 @@ def test_extern_dnnl_mobilenet():
 
     mod, params = relay.testing.mobilenet.get_workload(
         batch_size=1, dtype='float32')
-    mod["main"] = bind_params_by_name(mod["main"], params)
     mod = transform.AnnotateTarget(["dnnl"])(mod)
     mod = transform.MergeCompilerRegions()(mod)
     mod = transform.PartitionGraph()(mod)