diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index f3d7f4562533..4b3e1bc05367 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -338,6 +338,9 @@ runtime::Module CompileEthosn(const ObjectRef& ref) { TVM_REGISTER_GLOBAL("relay.ext.ethos-n").set_body_typed(CompileEthosn); +TVM_REGISTER_GLOBAL("relay.ext.ethos-n.constant_updater") + .set_body_typed([](Expr expr, std::string symbol) { return Map(); }); + } // namespace ethosn } // namespace contrib } // namespace relay diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7b71e34b777b..e24d18de931c 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -368,14 +368,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(tvm::attr::kGlobalSymbol); - std::string symobl = std::string(name_node.value()); - ConstantUpdater const_visit(symobl, ¶ms_); - const_visit(func); - + UpdateConstants(func, ¶ms_); return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3def6359c615..4426642e8e18 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -63,6 +63,37 @@ struct ConstantUpdater : public ExprVisitor { std::unordered_map* params_; }; +/*! + * \brief A function to update the params with constants found in an external function. + * \param func The function from which to get the constant params. + * \param params The params to update with the constants. 
+ */ +inline void UpdateConstants(Function func, + std::unordered_map* params) { + auto codegen = func->GetAttr(attr::kCompiler); + ICHECK(codegen.defined()) << "No external codegen is set"; + std::string codegen_name = codegen.value(); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + std::string symbol = std::string(name_node.value()); + std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; + // Get the constant updater for the external codegen + auto pf = tvm::runtime::Registry::Get(const_update_name); + // If the backend hasn't registered a constant updater, use a default one + if (pf == nullptr) { + ConstantUpdater const_visit(symbol, params); + const_visit(func); + } else { + Map constants = (*pf)(func, symbol); + for (const auto& it : constants) { + std::string const_name(it.first); + // Constant names should begin with the compiler name (to avoid conflicts) + ICHECK(const_name.find(codegen_name) == 0) + << "External constant names must start with compiler name"; + (*params)[const_name] = it.second; + } + } +} + /*! + * \brief A simple wrapper around ExprFunctor for a single argument case. + * The result of visit is memoized. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4a7e5eec17bc..f652644afa3c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1124,8 +1124,8 @@ void VMCompiler::Codegen() { if (target_str == "ext_dev") { // Collect metadata in functions that are handled by external codegen. 
ICHECK(mod->ContainGlobalVar(cfunc->func_name)); - backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); - const_visit(Downcast(mod->Lookup(cfunc->func_name))); + Function func = Downcast(mod->Lookup(cfunc->func_name)); + backend::UpdateConstants(func, ¶ms_); continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod); diff --git a/tests/python/contrib/test_ethosn/test_constant_duplication.py b/tests/python/contrib/test_ethosn/test_constant_duplication.py new file mode 100644 index 000000000000..a096e57c19a9 --- /dev/null +++ b/tests/python/contrib/test_ethosn/test_constant_duplication.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test that constants aren't duplicated for Ethos-N""" + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.op.contrib.ethosn import ethosn_available +from . 
import infrastructure as tei + + +def _get_model(): + """Return a model and any parameters it may have""" + shape = (1, 4, 4, 4) + kernel_h = 3 + kernel_w = 3 + out_channels = 8 + + a = relay.var("a", shape=shape, dtype="uint8") + add_const_value = tvm.nd.array(np.random.randint(0, high=10, size=shape, dtype="uint8")) + add_const = relay.const(add_const_value, "uint8") + a = relay.add(a, add_const) + weight_shape = (kernel_h, kernel_w, shape[3], out_channels) + w = tvm.nd.array(np.random.randint(low=0, high=255, size=weight_shape, dtype="uint8")) + weights = relay.const(w, "uint8") + conv = relay.qnn.op.conv2d( + a, + weights, + input_zero_point=relay.const(0, "int32"), + kernel_zero_point=relay.const(0, "int32"), + input_scale=relay.const(0.3, "float32"), + kernel_scale=relay.const(0.4, "float32"), + kernel_size=(kernel_h, kernel_w), + data_layout="NHWC", + kernel_layout="HWIO", + dilation=(1, 1), + strides=(1, 1), + groups=1, + channels=out_channels, + padding=(0, 0, 0, 0), + out_dtype="int32", + ) + b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32")) + biasc = relay.const(b, "int32") + bias = relay.nn.bias_add(conv, biasc, axis=3) + req = relay.qnn.op.requantize( + bias, + relay.const(0.3 * 0.4, "float32"), # input scale + relay.const(0, "int32"), # input zero point + relay.const(0.4, "float32"), # output scale + relay.const(0, "int32"), # output zero point + out_dtype="uint8", + ) + params = {"w": w, "b": b} + return req, params + + +def test_constant_duplication(): + if not ethosn_available(): + return + + model, params = _get_model() + mod = tei.make_module(model, params) + res = tei.build(mod, params, npu=True, expected_host_ops=1) + for key, value in res.params.items(): + assert key == "p0" + assert value.asnumpy().size == 64 diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c919e7ce1a7c..7e00cb717c92 100644 --- a/tests/python/relay/test_external_codegen.py 
+++ b/tests/python/relay/test_external_codegen.py @@ -219,6 +219,39 @@ def test_extern_gcc(): check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) +def test_extern_gcc_consts(): + @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") + def constant_updater(expr, symbol): + """A dummy constant updater just to test that a custom one works.""" + return {"ccompiler_0_p0": tvm.nd.array(y0_data)} + + x = relay.var("x", shape=(8, 8)) + y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + + x0 = relay.var("x0", shape=(8, 8)) + y0_const = relay.const(y0_data, "float32") + z = x0 + y0_const + f = relay.Function([x0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + call = relay.Call(f, [x]) + mod = tvm.IRModule.from_expr(call) + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + compiler = relay.backend.vm.VMCompiler() + compiler.lower(mod, "llvm") + compiler.codegen() + params = compiler.get_params() + assert len(params) == 1 + assert "ccompiler_0_p0" in params.keys() + + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + _, _, params = relay.build(mod, target="llvm") + assert len(params) == 1 + assert "ccompiler_0_p0" in params.keys() + + tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") + + def test_extern_dnnl(): if not tvm.get_global_func("relay.ext.dnnl", True): print("skip because DNNL codegen is not available") @@ -301,5 +334,6 @@ def test_extern_dnnl_const(): test_extern_gcc_single_op() test_extern_gcc_single_op_int() test_extern_gcc() + test_extern_gcc_consts() test_extern_dnnl() test_extern_dnnl_const()