From 104793de65e7142207e80797bc353445330c1135 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 15 Oct 2020 17:08:26 +0100 Subject: [PATCH 1/6] [BYOC] Allow custom codegens to register their own constant updater Currently, all codegens using BYOC must make use of the default ConstantUpdater pass. However, certain codegens, like Ethos-N, don't want to store any constants in metadata module. This provides an interface (via a global) to register a custom constant updating method and assigns a 'null' updater for the Ethos-N codegen. Change-Id: Ibd71d3091f992362eeede5d894eedb373b2dbc8f --- .../backend/contrib/codegen_c/codegen.cc | 28 +++++++ .../backend/contrib/ethosn/codegen_ethosn.h | 4 + src/relay/backend/graph_runtime_codegen.cc | 24 +++++- src/relay/backend/vm/compiler.cc | 23 +++++- .../test_ethosn/test_constant_duplication.py | 82 +++++++++++++++++++ tests/python/relay/test_external_codegen.py | 27 ++++++ 6 files changed, 183 insertions(+), 5 deletions(-) create mode 100644 tests/python/contrib/test_ethosn/test_constant_duplication.py diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 935ac16efb23..c07d32bc3724 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -290,8 +290,36 @@ runtime::Module CCompiler(const ObjectRef& ref) { return csource.CreateCSourceModule(ref); } +/*! + * \brief A visitor to add the constants used as params for MetadataModule. + */ +struct CCompilerConstantUpdater : public ExprVisitor { + public: + CCompilerConstantUpdater() = default; + + Map GetConstants(const Expr& expr) { + VisitExpr(expr); + return this->params_; + } + + void VisitExpr_(const ConstantNode* cn) final { + std::string name = "ccompiler_p" + std::to_string(const_idx_++); + params_.Set(name, cn->data); + } + + private: + int const_idx_{0}; + Map params_; +}; + +Map GetConstants(const Expr& expr) { + return CCompilerConstantUpdater().GetConstants(expr); +} + TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); +TVM_REGISTER_GLOBAL("relay.ext.ccompiler.constant_updater").set_body_typed(GetConstants); + } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index f3d7f4562533..c8230a08d5ae 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -338,6 +338,10 @@ runtime::Module CompileEthosn(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.ethos-n").set_body_typed(CompileEthosn); +TVM_REGISTER_GLOBAL("relay.ext.ethos-n.constant_updater").set_body_typed([](Expr expr) { + return Map(); +}); + } // namespace ethosn } // namespace contrib } // namespace relay diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7b71e34b777b..d14c08cdb80d 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -371,10 +371,28 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler); + ICHECK(codegen.defined()) << "No external codegen is set"; + std::string codegen_name = codegen.value(); const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - std::string symobl = std::string(name_node.value()); - ConstantUpdater const_visit(symobl, ¶ms_); - const_visit(func); + std::string symbol = std::string(name_node.value()); + std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; + // Get the constant updater for the external codegen + auto pf = tvm::runtime::Registry::Get(const_update_name); + // If the backend hasn't registered a constant updater, use a default one + if (pf == nullptr) { + ConstantUpdater const_visit(symbol, ¶ms_); + const_visit(func); + } else { + Map constants = (*pf)(func); + for (const auto& it : constants) { + std::string const_name(it.first); + // Constant names should begin this the compiler name (to avoid conflicts) + ICHECK(const_name.find(codegen_name) == 0) + << "External constant names must start with compiler name"; + params_[const_name] = it.second; + } + } return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4a7e5eec17bc..7e477b543638 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1124,8 +1124,27 @@ void VMCompiler::Codegen() { if (target_str == "ext_dev") { // Collect metadata in functions that are handled by external codegen. ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); - const_visit(Downcast(mod->Lookup(cfunc->func_name))); + Function func = Downcast(mod->Lookup(cfunc->func_name)); + auto codegen = func->GetAttr(attr::kCompiler); + ICHECK(codegen.defined()) << "No external codegen is set"; + std::string codegen_name = codegen.value(); + std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; + // Get the constant updater for the external codegen + auto pf = tvm::runtime::Registry::Get(const_update_name); + // If the backend hasn't registered a constant updater, use a default one + if (pf == nullptr) { + backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); + const_visit(func); + } else { + Map constants = (*pf)(func); + for (const auto& it : constants) { + std::string const_name(it.first); + // Constant names should begin this the compiler name (to avoid conflicts) + ICHECK(const_name.find(codegen_name) == 0) + << "External constant names must start with compiler name"; + params_[const_name] = it.second; + } + } continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod); diff --git a/tests/python/contrib/test_ethosn/test_constant_duplication.py b/tests/python/contrib/test_ethosn/test_constant_duplication.py new file mode 100644 index 000000000000..a096e57c19a9 --- /dev/null +++ b/tests/python/contrib/test_ethosn/test_constant_duplication.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test that constants aren't duplicated for Ethos-N""" + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.op.contrib.ethosn import ethosn_available +from . import infrastructure as tei + + +def _get_model(): + """Return a model and any parameters it may have""" + shape = (1, 4, 4, 4) + kernel_h = 3 + kernel_w = 3 + out_channels = 8 + + a = relay.var("a", shape=shape, dtype="uint8") + add_const_value = tvm.nd.array(np.random.randint(0, high=10, size=shape, dtype="uint8")) + add_const = relay.const(add_const_value, "uint8") + a = relay.add(a, add_const) + weight_shape = (kernel_h, kernel_w, shape[3], out_channels) + w = tvm.nd.array(np.random.randint(low=0, high=255, size=weight_shape, dtype="uint8")) + weights = relay.const(w, "uint8") + conv = relay.qnn.op.conv2d( + a, + weights, + input_zero_point=relay.const(0, "int32"), + kernel_zero_point=relay.const(0, "int32"), + input_scale=relay.const(0.3, "float32"), + kernel_scale=relay.const(0.4, "float32"), + kernel_size=(kernel_h, kernel_w), + data_layout="NHWC", + kernel_layout="HWIO", + dilation=(1, 1), + strides=(1, 1), + groups=1, + channels=out_channels, + padding=(0, 0, 0, 0), + out_dtype="int32", + ) + b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32")) + biasc = relay.const(b, "int32") + bias = relay.nn.bias_add(conv, biasc, axis=3) + req = relay.qnn.op.requantize( + bias, + relay.const(0.3 * 0.4, "float32"), # input zero scale + relay.const(0, "int32"), # input zero point + relay.const(0.4, "float32"), # output zero scale + relay.const(0, "int32"), # output zero point + out_dtype="uint8", + ) + params = {"w": w, "b": b} + return req, params + + +def test_constant_duplication(): + if not ethosn_available(): + return + + model, params = _get_model() + mod = tei.make_module(model, params) + res = tei.build(mod, params, npu=True, expected_host_ops=1) + for key, value in res.params.items(): + assert key == "p0" + assert value.asnumpy().size == 64 diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c919e7ce1a7c..24019101334f 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -219,6 +219,32 @@ def test_extern_gcc(): check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) +def test_extern_gcc_consts(): + x = relay.var("x", shape=(8, 8)) + y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + + x0 = relay.var("x0", shape=(8, 8)) + y0_const = relay.const(y0_data, "float32") + z = x0 + y0_const + f = relay.Function([x0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + call = relay.Call(f, [x]) + mod = tvm.IRModule.from_expr(call) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + compiler = relay.backend.vm.VMCompiler() + compiler.lower(mod, "llvm") + compiler.codegen() + params = compiler.get_params() + assert len(params) == 1 + assert "ccompiler_p0" in params.keys() + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + _, _, params = relay.build(mod, target="llvm") + assert len(params) == 1 + assert "ccompiler_p0" in params.keys() + + def test_extern_dnnl(): if not tvm.get_global_func("relay.ext.dnnl", True): print("skip because DNNL codegen is not available") @@ -301,5 +327,6 @@ def test_extern_dnnl_const(): test_extern_gcc_single_op() test_extern_gcc_single_op_int() test_extern_gcc() + test_extern_gcc_consts() test_extern_dnnl() test_extern_dnnl_const() From d35e6683050f456737dfa4197b227a7c1302e256 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 16 Oct 2020 15:28:46 +0100 Subject: [PATCH 2/6] Fix to use symbol in const name Change-Id: I0ade81af9002d413c5b20a50488018e8cd8d8bad --- src/relay/backend/contrib/codegen_c/codegen.cc | 9 +++++---- src/relay/backend/contrib/codegen_c/codegen_c.h | 2 +- src/relay/backend/contrib/ethosn/codegen_ethosn.h | 5 ++--- src/relay/backend/graph_runtime_codegen.cc | 2 +- src/relay/backend/vm/compiler.cc | 2 +- tests/python/relay/test_external_codegen.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index c07d32bc3724..07d5330de05c 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -295,7 +295,7 @@ runtime::Module CCompiler(const ObjectRef& ref) { */ struct CCompilerConstantUpdater : public ExprVisitor { public: - CCompilerConstantUpdater() = default; + explicit CCompilerConstantUpdater(const std::string& symbol) : symbol_(symbol) {}; Map GetConstants(const Expr& expr) { VisitExpr(expr); @@ -303,17 +303,18 @@ struct CCompilerConstantUpdater : public ExprVisitor { } void VisitExpr_(const ConstantNode* cn) final { - std::string name = "ccompiler_p" + std::to_string(const_idx_++); + std::string name = symbol_ + "_p" + std::to_string(const_idx_++); params_.Set(name, cn->data); } private: int const_idx_{0}; + std::string symbol_; Map params_; }; -Map GetConstants(const Expr& expr) { - return CCompilerConstantUpdater().GetConstants(expr); +Map GetConstants(const Expr& expr, const std::string symbol) { + return CCompilerConstantUpdater(symbol).GetConstants(expr); } TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 9448b4d0738d..759831c5f184 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -334,7 +334,7 @@ class CodegenCBase { * \return The created variable name */ std::string CreateConstVar(const std::string& symbol, int const_id) const { - return symbol + "_const_" + std::to_string(const_id++); + return symbol + "_p" + std::to_string(const_id++); } /*! \brief The external function source code stream. */ diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index c8230a08d5ae..4b3e1bc05367 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -338,9 +338,8 @@ runtime::Module CompileEthosn(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.ethos-n").set_body_typed(CompileEthosn); -TVM_REGISTER_GLOBAL("relay.ext.ethos-n.constant_updater").set_body_typed([](Expr expr) { - return Map(); -}); +TVM_REGISTER_GLOBAL("relay.ext.ethos-n.constant_updater") + .set_body_typed([](Expr expr, std::string symbol) { return Map(); }); } // namespace ethosn } // namespace contrib diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index d14c08cdb80d..655fdf013d1e 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -384,7 +384,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator constants = (*pf)(func); + Map constants = (*pf)(func, symbol); for (const auto& it : constants) { std::string const_name(it.first); // Constant names should begin this the compiler name (to avoid conflicts) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7e477b543638..8b8a38f00e84 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1136,7 +1136,7 @@ void VMCompiler::Codegen() { backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); const_visit(func); } else { - Map constants = (*pf)(func); + Map constants = (*pf)(func, cfunc->func_name); for (const auto& it : constants) { std::string const_name(it.first); // Constant names should begin this the compiler name (to avoid conflicts) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 24019101334f..c602c87e93a4 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -237,12 +237,12 @@ def test_extern_gcc_consts(): compiler.codegen() params = compiler.get_params() assert len(params) == 1 - assert "ccompiler_p0" in params.keys() + assert "ccompiler_0_p0" in params.keys() with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): _, _, params = relay.build(mod, target="llvm") assert len(params) == 1 - assert "ccompiler_p0" in params.keys() + assert "ccompiler_0_p0" in params.keys() def test_extern_dnnl(): From 3878667f667c9ab135d61501839251b097aa934c Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 16 Oct 2020 16:18:52 +0100 Subject: [PATCH 3/6] Remove ; Change-Id: I61967bc4997efb87f87b49dad7e0a660c536ef35 --- src/relay/backend/contrib/codegen_c/codegen.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 07d5330de05c..ea1ba50f7645 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -295,7 +295,7 @@ runtime::Module CCompiler(const ObjectRef& ref) { */ struct CCompilerConstantUpdater : public ExprVisitor { public: - explicit CCompilerConstantUpdater(const std::string& symbol) : symbol_(symbol) {}; + explicit CCompilerConstantUpdater(const std::string& symbol) : symbol_(symbol) {} Map GetConstants(const Expr& expr) { VisitExpr(expr); From f5924b171244702f6836069c893ca7538fa92d54 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 26 Oct 2020 14:56:24 +0000 Subject: [PATCH 4/6] Remove ccompiler constant updater Change-Id: Iea9ee0f689683512fa114afeadeccb7fc9048e4f --- .../backend/contrib/codegen_c/codegen.cc | 29 ------------------- .../backend/contrib/codegen_c/codegen_c.h | 2 +- tests/python/relay/test_external_codegen.py | 5 ++++ 3 files changed, 6 insertions(+), 30 deletions(-) diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index ea1ba50f7645..935ac16efb23 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -290,37 +290,8 @@ runtime::Module CCompiler(const ObjectRef& ref) { return csource.CreateCSourceModule(ref); } -/*! - * \brief A visitor to add the constants used as params for MetadataModule. - */ -struct CCompilerConstantUpdater : public ExprVisitor { - public: - explicit CCompilerConstantUpdater(const std::string& symbol) : symbol_(symbol) {} - - Map GetConstants(const Expr& expr) { - VisitExpr(expr); - return this->params_; - } - - void VisitExpr_(const ConstantNode* cn) final { - std::string name = symbol_ + "_p" + std::to_string(const_idx_++); - params_.Set(name, cn->data); - } - - private: - int const_idx_{0}; - std::string symbol_; - Map params_; -}; - -Map GetConstants(const Expr& expr, const std::string symbol) { - return CCompilerConstantUpdater(symbol).GetConstants(expr); -} - TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); -TVM_REGISTER_GLOBAL("relay.ext.ccompiler.constant_updater").set_body_typed(GetConstants); - } // namespace contrib } // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 759831c5f184..9448b4d0738d 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -334,7 +334,7 @@ class CodegenCBase { * \return The created variable name */ std::string CreateConstVar(const std::string& symbol, int const_id) const { - return symbol + "_p" + std::to_string(const_id++); + return symbol + "_const_" + std::to_string(const_id++); } /*! \brief The external function source code stream. */ diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c602c87e93a4..72b47c919539 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -220,6 +220,11 @@ def test_extern_gcc(): def test_extern_gcc_consts(): + @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") + def constant_updater(expr, symbol): + """A dummy constant updater just to test that a custom one works.""" + return {"ccompiler_0_p0": tvm.nd.array(y0_data)} + x = relay.var("x", shape=(8, 8)) y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") From a7871661495650245aec099022bd5e68e9046d01 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 29 Oct 2020 10:02:53 +0000 Subject: [PATCH 5/6] Unregister updater after test Change-Id: I8009940bb2ac949f2c3f0d72c943a5b74afd6954 --- tests/python/relay/test_external_codegen.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 72b47c919539..7e00cb717c92 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -249,6 +249,8 @@ def constant_updater(expr, symbol): assert len(params) == 1 assert "ccompiler_0_p0" in params.keys() + tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") + def test_extern_dnnl(): if not tvm.get_global_func("relay.ext.dnnl", True): From 0c43201428a93473f57fa42dc01d9395bbfee37b Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 29 Oct 2020 16:06:53 +0000 Subject: [PATCH 6/6] Create UpdateConstants utility function Change-Id: I83c8c6f92cfe3be3a7e811e98a4eec17590186ff --- src/relay/backend/graph_runtime_codegen.cc | 27 +------------------ src/relay/backend/utils.h | 31 ++++++++++++++++++++++ src/relay/backend/vm/compiler.cc | 21 +-------------- 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 655fdf013d1e..e24d18de931c 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -368,32 +368,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler); - ICHECK(codegen.defined()) << "No external codegen is set"; - std::string codegen_name = codegen.value(); - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); - std::string symbol = std::string(name_node.value()); - std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; - // Get the constant updater for the external codegen - auto pf = tvm::runtime::Registry::Get(const_update_name); - // If the backend hasn't registered a constant updater, use a default one - if (pf == nullptr) { - ConstantUpdater const_visit(symbol, ¶ms_); - const_visit(func); - } else { - Map constants = (*pf)(func, symbol); - for (const auto& it : constants) { - std::string const_name(it.first); - // Constant names should begin this the compiler name (to avoid conflicts) - ICHECK(const_name.find(codegen_name) == 0) - << "External constant names must start with compiler name"; - params_[const_name] = it.second; - } - } - + UpdateConstants(func, ¶ms_); return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3def6359c615..4426642e8e18 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -63,6 +63,37 @@ struct ConstantUpdater : public ExprVisitor { std::unordered_map* params_; }; +/*! + * \brief A function to update the params with constants found in an external function. + * \param func The function from which to get the constant params. + * \param params The params to update with the constants. + */ +inline void UpdateConstants(Function func, + std::unordered_map* params) { + auto codegen = func->GetAttr(attr::kCompiler); + ICHECK(codegen.defined()) << "No external codegen is set"; + std::string codegen_name = codegen.value(); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + std::string symbol = std::string(name_node.value()); + std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; + // Get the constant updater for the external codegen + auto pf = tvm::runtime::Registry::Get(const_update_name); + // If the backend hasn't registered a constant updater, use a default one + if (pf == nullptr) { + ConstantUpdater const_visit(symbol, params); + const_visit(func); + } else { + Map constants = (*pf)(func, symbol); + for (const auto& it : constants) { + std::string const_name(it.first); + // Constant names should begin this the compiler name (to avoid conflicts) + ICHECK(const_name.find(codegen_name) == 0) + << "External constant names must start with compiler name"; + (*params)[const_name] = it.second; + } + } +} + /*! * \brief A simple wrapper around ExprFunctor for a single argument case. * The result of visit is memoized. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 8b8a38f00e84..f652644afa3c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1125,26 +1125,7 @@ void VMCompiler::Codegen() { // Collect metadata in functions that are handled by external codegen. ICHECK(mod->ContainGlobalVar(cfunc->func_name)); Function func = Downcast(mod->Lookup(cfunc->func_name)); - auto codegen = func->GetAttr(attr::kCompiler); - ICHECK(codegen.defined()) << "No external codegen is set"; - std::string codegen_name = codegen.value(); - std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; - // Get the constant updater for the external codegen - auto pf = tvm::runtime::Registry::Get(const_update_name); - // If the backend hasn't registered a constant updater, use a default one - if (pf == nullptr) { - backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); - const_visit(func); - } else { - Map constants = (*pf)(func, cfunc->func_name); - for (const auto& it : constants) { - std::string const_name(it.first); - // Constant names should begin this the compiler name (to avoid conflicts) - ICHECK(const_name.find(codegen_name) == 0) - << "External constant names must start with compiler name"; - params_[const_name] = it.second; - } - } + backend::UpdateConstants(func, ¶ms_); continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod);