diff --git a/python/tvm/relay/op/contrib/_ethosn.py b/python/tvm/relay/op/contrib/_ethosn.py
index ea2915675ec6..9c7c922fdfb0 100644
--- a/python/tvm/relay/op/contrib/_ethosn.py
+++ b/python/tvm/relay/op/contrib/_ethosn.py
@@ -20,3 +20,4 @@
 import tvm._ffi
 
 tvm._ffi._init_api("relay.ethos-n.support", __name__)
+tvm._ffi._init_api("relay.backend.contrib.ethos-n", __name__)
diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 469939ecf0b8..73dd6b735775 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -25,7 +25,7 @@
 from tvm.relay.build_module import bind_params_by_name
 
 from ...dataflow_pattern import is_constant, is_op, wildcard
-from . import _ethosn as support
+from . import _ethosn
 from .register import register_pattern_table
 
 
@@ -60,6 +60,18 @@ def ethosn_api_version() -> str:
     return tvm.get_global_func("relay.ethos-n.api.version")()
 
 
+def ConvertEquivalents() -> tvm.transform.Pass:  # pylint: disable=invalid-name
+    """Converts operations into a numerically equivalent form
+    that can be understood by the NPU codegen.
+
+    Returns
+    -------
+    Pass
+        The module pass.
+    """
+    return _ethosn.ConvertEquivalents()
+
+
 def partition_for_ethosn(mod, params=None, **opts):
     """Partition the graph greedily offloading supported
     operators to Arm Ethos-N NPU.
@@ -107,9 +119,9 @@ def partition_for_ethosn(mod, params=None, **opts):
             transform.AnnotateTarget("ethos-n"),
             transform.MergeCompilerRegions(),
             transform.PartitionGraph(),
+            ConvertEquivalents(),
         ]
     )
-
     return seq(mod)
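The hunk above wires `ConvertEquivalents` in as the final step of `partition_for_ethosn`, so the conversion only runs on functions already partitioned for the NPU. A minimal sketch of the resulting flow (illustrative only: it assumes a TVM build with `USE_ETHOSN` enabled, and the shapes and quantization parameters are made up):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.ethosn import partition_for_ethosn

# qnn.mul in the supported form: one input is a constant of shape
# (1, ..., C) that broadcasts over the channels of the other input.
x = relay.var("x", shape=(1, 4, 4, 8), dtype="uint8")
c = relay.const(np.ones((1, 1, 1, 8), dtype="uint8"))
mul = relay.qnn.op.mul(
    x, c,
    relay.const(0.5, "float32"), relay.const(0, "int32"),    # lhs scale/zp
    relay.const(0.25, "float32"), relay.const(0, "int32"),   # rhs scale/zp
    relay.const(0.125, "float32"), relay.const(0, "int32"),  # output scale/zp
)
mod = tvm.IRModule.from_expr(relay.Function([x], mul))

# After partitioning, the offloaded composite should be tagged
# "ethos-n.qnn_conv2d" rather than "ethos-n.qnn_mul", because
# ConvertEquivalents has rewritten it to a depthwise convolution.
mod = partition_for_ethosn(mod)
print(mod)
```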
@@ -183,70 +195,102 @@ def qnn_resize_pattern():
         )
         return pattern
 
+    def qnn_mul_pattern():
+        """
+        Multiply is supported when one input is a constant of shape [1, ..., C],
+        where C matches the number of channels of the other input.
+        """
+        mul_op = is_op("qnn.mul")
+        gen_mul_inputs = lambda x, y: mul_op(
+            x,
+            y,
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+            is_constant(),
+        )
+        input_is_left = gen_mul_inputs(wildcard(), is_constant())
+        input_is_right = gen_mul_inputs(is_constant(), wildcard())
+        return input_is_left | input_is_right
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.conv2d(extract)
+        return _ethosn.conv2d(extract)
 
     def check_fc(extract):
         """Check if a fully connected is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.fc(extract)
+        return _ethosn.fc(extract)
 
     def check_avg_pool2d(extract):
         """Check if a avg pool2d is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.avg_pool2d(extract)
+        return _ethosn.avg_pool2d(extract)
 
     def check_mean(extract):
         """Check if mean is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.mean(extract)
+        return _ethosn.mean(extract)
 
     def check_sigmoid(extract):
         """Check if a sigmoid is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.sigmoid(extract)
+        return _ethosn.sigmoid(extract)
 
     def check_tanh(extract):
         """Check if tanh is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
-        return support.tanh(extract)
+        return _ethosn.tanh(extract)
 
     def check_leaky_relu(extract):
         """Check if Leaky ReLU is supported."""
         if not ethosn_available():
             return False
 
-        return support.leaky_relu(extract)
+        return _ethosn.leaky_relu(extract)
+
+    def check_mul(extract):
+        """Check if Mul is supported."""
+        if not ethosn_available():
+            return False
+        # Do not support scalar constants for now
+        check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
+        if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
+            return False
+        extract = _ethosn.ConvertQnnMultiply(extract)
+        return _ethosn.conv2d(extract)
 
     def check_requantize(extract):
         """Check if requantize is supported."""
         if not ethosn_available():
             return False
 
-        return support.requantize(extract)
+        return _ethosn.requantize(extract)
 
     def check_resize(extract):
         """Check if resize (nearest neighbor) is supported."""
         if not ethosn_available():
             return False
 
-        return support.resize(extract)
+        return _ethosn.resize(extract)
 
     return [
+        ("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
         ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
         ("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
@@ -274,7 +318,7 @@ def max_pool2d(expr):
     if not ethosn_available():
         return False
 
-    return support.max_pool2d(expr)
+    return _ethosn.max_pool2d(expr)
 
 
 @tvm.ir.register_op_attr("reshape", "target.ethos-n")
@@ -285,7 +329,7 @@ def reshape(expr):
     if not _is_ethosn_composite(expr.args[0]):
         return False
 
-    return support.reshape(expr)
+    return _ethosn.reshape(expr)
 
 
 @tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
@@ -294,7 +338,7 @@ def qnn_add(expr):
     if not ethosn_available():
         return False
 
-    return support.addition(expr)
+    return _ethosn.addition(expr)
 
 
 @tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
@@ -302,7 +346,7 @@ def qnn_concatenate(expr):
     """Check if a concatenate is supported by Ethos-N."""
     if not ethosn_available():
         return False
-    if not support.concatenate(expr):
+    if not _ethosn.concatenate(expr):
         return False
 
     # Support library has some unenforced restrictions on qnn params
@@ -332,7 +376,7 @@ def split(expr):
         return False
     if ethosn_api_version() >= LooseVersion("3.0.1"):
         return False
-    if not support.split(expr):
+    if not _ethosn.split(expr):
         return False
 
     return True
@@ -343,7 +387,7 @@ def depth_to_space(expr):
    """Check if a depth_to_space is supported by Ethos-N."""
     if not ethosn_available():
         return False
-    if not support.depth_to_space(expr):
+    if not _ethosn.depth_to_space(expr):
         return False
 
     return True
@@ -354,7 +398,7 @@ def clip(expr):
     """Check if a clip is supported by Ethos-N."""
     if not ethosn_available():
         return False
-    if not support.relu(expr):
+    if not _ethosn.relu(expr):
         return False
 
     return True
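Before moving to the C++ half of the change, the new pattern is easy to sanity-check in isolation. The sketch below (illustrative: shapes and quantization parameters are made up, and no NPU-enabled build is needed for pure pattern matching) reproduces `qnn_mul_pattern` and confirms it matches `qnn.mul` with the constant on either side:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

# Same construction as qnn_mul_pattern: the six trailing qnn params are
# always constants; the data operands may appear in either order.
mul_op = is_op("qnn.mul")
gen = lambda a, b: mul_op(a, b, *[is_constant()] * 6)
pattern = gen(wildcard(), is_constant()) | gen(is_constant(), wildcard())

x = relay.var("x", shape=(1, 4, 4, 8), dtype="uint8")
c = relay.const(np.ones((1, 1, 1, 8), dtype="uint8"))
qparams = [
    relay.const(0.5, "float32"), relay.const(0, "int32"),
    relay.const(0.25, "float32"), relay.const(0, "int32"),
    relay.const(0.125, "float32"), relay.const(0, "int32"),
]
assert pattern.match(relay.qnn.op.mul(x, c, *qparams))  # constant on the right
assert pattern.match(relay.qnn.op.mul(c, x, *qparams))  # constant on the left
```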
diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
new file mode 100644
index 000000000000..6b64467047f4
--- /dev/null
+++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/ethosn/convert_equivalent.cc
+ * \brief Converts operations into a numerically equivalent form
+ * that can be understood by the NPU codegen.
+ */
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include <utility>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+#include "../../../transforms/simplify_expr.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace ethosn {
+
+/*!
+ * \brief Converts qnn.mul to a mathematically equivalent
+ * qnn.conv2d depthwise operation.
+ */
+Expr ConvertQnnMultiply(const Expr& expr) {
+  Call call = Downcast<Call>(expr);
+
+  Expr input1 = call->args[0];
+  Expr input2 = call->args[1];
+  Expr input1_scale = call->args[2];
+  Expr input1_zero_point = call->args[3];
+  Expr input2_scale = call->args[4];
+  Expr input2_zero_point = call->args[5];
+  // Reverse the inputs if the constant is the first input
+  if (call->args[0]->IsInstance<ConstantNode>()) {
+    input1 = call->args[1];
+    input2 = call->args[0];
+    input1_scale = call->args[4];
+    input1_zero_point = call->args[5];
+    input2_scale = call->args[2];
+    input2_zero_point = call->args[3];
+  }
+  Expr output_scale = call->args[6];
+  Expr output_zero_point = call->args[7];
+
+  const auto* input_constant = input2.as<ConstantNode>();
+  ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey();
+  const auto* input_constant_tt = input_constant->checked_type().as<TensorTypeNode>();
+  int channels = input_constant_tt->shape.back().as<IntImmNode>()->value;
+
+  // Repack the broadcastable (1, ..., C) constant as an HWOI depthwise
+  // kernel of shape {1, 1, C, 1}.
+  runtime::NDArray input_data = input_constant->data;
+  runtime::NDArray kernel_data_hwoi =
+      runtime::NDArray::Empty({1, 1, channels, 1}, input_data->dtype, input_data->device);
+  kernel_data_hwoi.CopyFrom(input_data);
+  Constant kernel = Constant(kernel_data_hwoi, input_constant->span);
+
+  Type output_type = expr->checked_type();
+  auto output_tt = output_type.as<TensorTypeNode>();
+  ICHECK(output_tt) << "Expected TensorTypeNode but got " << output_type->GetTypeKey();
+  DataType output_dtype = output_tt->dtype;
+
+  Expr conv2d = qnn::MakeQnnConv2D(
+      input1, kernel, input1_zero_point, input2_zero_point, input1_scale, input2_scale, {1, 1},
+      {0, 0, 0, 0}, {1, 1}, channels, channels, {1, 1}, "NHWC", "HWOI", "NHWC", DataType::Int(32));
+  Constant bias_data = MakeConstantZeros(DataType::Int(32), {channels});
+  Expr bias_add = MakeBiasAdd(conv2d, bias_data, 3);
+  Expr requantize = qnn::MakeRequantize(bias_add, input1_scale, input1_zero_point, output_scale,
+                                        output_zero_point, -1, "None", "None", output_dtype);
+
+  return InferType(requantize);
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnMultiply")
+    .set_body_typed(ConvertQnnMultiply);
+
+class ConvertEquivalentsMutator : public MixedModeMutator {
+ public:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    Call call = Downcast<Call>(post);
+    if (!call->op->IsInstance<FunctionNode>()) {
+      return post;
+    }
+
+    Function func = Downcast<Function>(call->op);
+    Function new_func = Function(func);
+    auto composite_name = func->GetAttr<String>(attr::kComposite);
+    if (composite_name == "ethos-n.qnn_mul") {
+      Expr new_func_body = ConvertQnnMultiply(func->body);
+      new_func = WithFields(func, func->params, new_func_body);
+      new_func = WithAttr(std::move(new_func), attr::kComposite, String("ethos-n.qnn_conv2d"));
+    }
+
+    Call new_call = WithFields(call, new_func);
+    return Downcast<Call>(new_call);
+  }
+};
+
+tvm::transform::Pass ConvertEquivalents() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule mod, transform::PassContext ctx) {
+        for (auto gv : mod->GetGlobalVars()) {
+          Function func = Downcast<Function>(mod->Lookup(gv));
+          auto compiler_name = func->GetAttr<String>(attr::kCompiler);
+          if (compiler_name.defined() && compiler_name == "ethos-n") {
+            auto new_body = ConvertEquivalentsMutator().VisitExpr(func->body);
+            if (!new_body.same_as(func->body)) {
+              Function new_func = WithFields(func, func->params, new_body);
+              mod->Update(gv, new_func);
+            }
+          }
+        }
+        return mod;
+      };
+  return tvm::transform::CreateModulePass(
+      pass_func, 0, "relay.backend.contrib.ethos-n.ConvertEquivalents", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertEquivalents")
+    .set_body_typed(ConvertEquivalents);
+
+}  // namespace ethosn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
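The rewrite above is justified by a small identity: multiplying an NHWC tensor elementwise by a per-channel constant `k` is exactly a 1x1 depthwise convolution whose HWOI kernel holds `k[c]` at position `(0, 0, c, 0)`. A numpy check of that identity (illustrative only, stated in the float domain; the pass performs the same rewrite on the quantized graph):

```python
import numpy as np

n, h, w, c = 1, 4, 4, 8
x = np.random.rand(n, h, w, c).astype("float32")
k = np.random.rand(1, 1, 1, c).astype("float32")  # qnn.mul constant, broadcast over NHW

mul_out = x * k  # elementwise multiply

# 1x1 depthwise convolution with groups == channels: each output channel
# sees only its own input channel and a single kernel value.
k_hwoi = k.reshape(1, 1, c, 1)
dw_out = np.empty_like(x)
for ch in range(c):
    dw_out[..., ch] = x[..., ch] * k_hwoi[0, 0, ch, 0]

np.testing.assert_allclose(mul_out, dw_out)
```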
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index c850bf8958c9..85938a739182 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -117,6 +117,8 @@ Expr MakeShapeOf(Expr data, DataType dtype);
 
 Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);
 
+Expr MakeBiasAdd(Expr data, Expr bias, int axis);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h
index 18c592f2ed69..d084e4871e95 100644
--- a/src/relay/qnn/utils.h
+++ b/src/relay/qnn/utils.h
@@ -121,6 +121,10 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
                     attrs.operator->(), input_shape, attrs->out_dtype);
 }
 
+Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
+                    Expr output_zero_point, int axis, String rounding, String compute_dtype,
+                    DataType out_dtype);
+
 Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                      const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                      const DequantizeAttrs* attrs);
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index d05d39b733d3..ffe1cc2ca2ab 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -344,6 +344,40 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
   return Constant(arr);
 }
 
+/*!
+ * \brief Create a Constant tensor of zeros.
+ *
+ * \param dtype The data type.
+ * \param shape The shape of the output constant tensor.
+ * \return A Constant.
+ */
+static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) {
+  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
+  int64_t data_size = 1;
+  for (int64_t dim : shape) {
+    data_size *= dim;
+  }
+  TVM_DTYPE_DISPATCH(dtype, DType, {
+    for (int64_t i = 0; i < data_size; i++) {
+      if (dtype == DataType::Float(16)) {
+        // convert to float16
+        // storage is uint16_t
+        // Similar handling as that in MakeConstantScalar
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0));
+      } else if (dtype == DataType::BFloat(16)) {
+        // convert to bfloat16
+        // storage is uint16_t
+        *(static_cast<DType*>(arr->data) + i) =
+            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 8>(static_cast<float>(0));
+      } else {
+        *(static_cast<DType*>(arr->data) + i) = 0;
+      }
+    }
+  })
+  return Constant(arr);
+}
+
 /*!
  * \brief Check whether a shape is static and create corresponding Constant.
  Eventually this will be removed and replaced with CheckConstantShapeArrayInteger
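The three declarations exposed above (`MakeBiasAdd`, `MakeRequantize`, `MakeConstantZeros`) are exactly the building blocks `ConvertQnnMultiply` strings together: conv2d, then a zero `bias_add`, then a `requantize`. A Python rendering of the emitted structure (illustrative: the shapes and quantization parameters are made up; the authoritative form is the `expected()` function in the unit test below):

```python
import numpy as np
import tvm
from tvm import relay

channels = 8
x = relay.var("x", shape=(1, 4, 4, channels), dtype="uint8")
kernel = relay.const(np.ones((1, 1, channels, 1), dtype="uint8"))  # HWOI depthwise kernel

# groups == channels makes this a depthwise conv; int32 accumulation.
conv = relay.qnn.op.conv2d(
    x, kernel,
    relay.const(0, "int32"), relay.const(0, "int32"),          # input/kernel zero points
    relay.const(0.5, "float32"), relay.const(0.25, "float32"),  # input/kernel scales
    kernel_size=(1, 1), channels=channels, groups=channels,
    data_layout="NHWC", kernel_layout="HWOI", out_dtype="int32",
)
bias = relay.nn.bias_add(conv, relay.const(np.zeros((channels,), dtype="int32")), axis=3)
out = relay.qnn.op.requantize(
    bias,
    relay.const(0.5, "float32"), relay.const(0, "int32"),
    relay.const(0.125, "float32"), relay.const(0, "int32"),
    out_dtype="uint8",
)
print(tvm.IRModule.from_expr(relay.Function([x], out)))
```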
diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
new file mode 100644
index 000000000000..570009422067
--- /dev/null
+++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for the convert equivalents pass."""
+
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.testing import requires_ethosn
+from tvm.relay.op.contrib.ethosn import ConvertEquivalents
+
+from . import infrastructure as tei
+
+
+def _assert_structural_equal(a, b):
+    """Check structural equality of two Relay expressions."""
+    reason = (
+        "Actual and expected relay functions are not equal. "
+        "ConvertEquivalents is not correctly transforming the input "
+        "graph."
+    )
+    assert tvm.ir.structural_equal(a, b), reason
+
+
+def _create_npu_module(inputs, expr, composite_name, ext_func_name):
+    """Wraps an operator as an NPU module."""
+    gen_vars = lambda prefix, vars: [
+        relay.var(
+            prefix + var.name_hint, shape=var.type_annotation.shape, dtype=var.type_annotation.dtype
+        )
+        for var in vars
+    ]
+
+    mod = tvm.ir.IRModule()
+
+    func = relay.Function(relay.analysis.free_vars(expr), expr)
+    func = func.with_attr("Composite", composite_name)
+    inner_vars = gen_vars("inner_", inputs)
+    call = relay.Call(func, inner_vars)
+
+    func2 = relay.Function(relay.analysis.free_vars(call), call)
+    func2 = func2.with_attr("Compiler", "ethos-n")
+    func2 = func2.with_attr("global_symbol", ext_func_name)
+    mod[ext_func_name] = func2
+    mod = relay.transform.InferType()(mod)
+
+    outer_vars = gen_vars("outer_", inputs)
+    out = relay.Call(mod.get_global_var(ext_func_name), outer_vars)
+    mod["main"] = relay.Function(relay.analysis.free_vars(out), out)
+    mod = relay.transform.InferType()(mod)
+    return mod
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize("shape,channels", [((1, 4, 4, 8), 8), ((1, 16, 12, 4), 4)])
+@pytest.mark.parametrize("reverse_inputs", [True, False])
+def test_multiply_to_depthwise(dtype, shape, channels, reverse_inputs):
+    """Check that multiply is correctly converted to a depthwise operation."""
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    input2_zp = np.random.randint(data_min, data_max)
+    input2_sc = np.random.random() * 2
+    output_zp, output_sc = tei.get_conv2d_qnn_params(
+        dtype, input_zp, input_sc, input2_zp, input2_sc, 1, 1, shape[3]
+    )
+    x = relay.var("x", shape=shape, dtype=dtype)
+    constant_shape = (1, 1, 1, channels)
+    y_data = np.random.randint(data_min, data_max + 1, size=constant_shape, dtype=dtype)
+
+    def before():
+        y = relay.const(y_data, dtype=dtype)
+        expr = relay.qnn.op.mul(
+            y if reverse_inputs else x,
+            x if reverse_inputs else y,
+            relay.const(input_sc, "float32"),
+            relay.const(input_zp, "int32"),
+            relay.const(input2_sc, "float32"),
+            relay.const(input2_zp, "int32"),
+            relay.const(output_sc, "float32"),
+            relay.const(output_zp, "int32"),
+        )
+        return _create_npu_module([x], expr, "ethos-n.qnn_mul", "ext_func")
+
+    def expected():
+        constant_shape_hwoi = (1, 1, channels, 1)
+        y_data_hwoi = y_data.reshape(constant_shape_hwoi)
+        y_hwoi = relay.const(y_data_hwoi, dtype=dtype)
+        expr = relay.qnn.op.conv2d(
+            x,
+            y_hwoi,
+            relay.const(input2_zp if reverse_inputs else input_zp, "int32"),
+            relay.const(input_zp if reverse_inputs else input2_zp, "int32"),
+            relay.const(input2_sc if reverse_inputs else input_sc, "float32"),
+            relay.const(input_sc if reverse_inputs else input2_sc, "float32"),
+            (1, 1),
+            channels,
+            (1, 1),
+            (0, 0),
+            (1, 1),
+            channels,
+            "NHWC",
+            "HWOI",
+            "NHWC",
+            "int32",
+        )
+        expr = relay.nn.bias_add(expr, relay.const(np.zeros((channels,), dtype="int32")), axis=3)
+        expr = relay.qnn.op.requantize(
+            expr,
+            relay.const(input2_sc if reverse_inputs else input_sc, "float32"),
+            relay.const(input2_zp if reverse_inputs else input_zp, "int32"),
+            relay.const(output_sc, "float32"),
+            relay.const(output_zp, "int32"),
+            out_dtype=dtype,
+        )
+        return _create_npu_module([x], expr, "ethos-n.qnn_conv2d", "ext_func")
+
+    mod = before()
+    mod = ConvertEquivalents()(mod)
+    expected_mod = expected()
+    _assert_structural_equal(mod["ext_func"], expected_mod["ext_func"])
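Between the unit test above and the integration tests below, the one piece of data plumbing worth calling out is the kernel repacking: the `(1, 1, 1, C)` multiplier constant becomes an HWOI kernel of shape `(1, 1, C, 1)`, so channel `c` of the input is paired with kernel value `c`. A short numpy check (illustrative):

```python
import numpy as np

c = 8
y = np.arange(c, dtype="uint8").reshape(1, 1, 1, c)  # qnn.mul constant, NHWC-broadcast
y_hwoi = y.reshape(1, 1, c, 1)                       # depthwise kernel, HWOI
assert all(y[0, 0, 0, ch] == y_hwoi[0, 0, ch, 0] for ch in range(c))
```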
diff --git a/tests/python/contrib/test_ethosn/test_multiply.py b/tests/python/contrib/test_ethosn/test_multiply.py
new file mode 100644
index 000000000000..38d8516b6721
--- /dev/null
+++ b/tests/python/contrib/test_ethosn/test_multiply.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Integration tests for Multiply."""
+
+import pytest
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.testing import requires_ethosn
+
+from . import infrastructure as tei
+
+
+def _get_model(
+    shape,
+    constant_shape,
+    input_zp,
+    input_sc,
+    input2_zp,
+    input2_sc,
+    output_zp,
+    output_sc,
+    dtype,
+    reverse_inputs=False,
+):
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+
+    x = relay.var("x", shape=shape, dtype=dtype)
+    y_data = np.random.randint(data_min, data_max + 1, size=constant_shape, dtype=dtype)
+    y = relay.const(y_data, dtype=dtype)
+
+    out = relay.qnn.op.mul(
+        y if reverse_inputs else x,
+        x if reverse_inputs else y,
+        relay.const(input_sc, "float32"),
+        relay.const(input_zp, "int32"),
+        relay.const(input2_sc, "float32"),
+        relay.const(input2_zp, "int32"),
+        relay.const(output_sc, "float32"),
+        relay.const(output_zp, "int32"),
+    )
+    params = {"y": y_data}
+    return out, params
+
+
+@requires_ethosn
+@pytest.mark.parametrize("dtype", ["uint8", "int8"])
+@pytest.mark.parametrize(
+    "shape,constant_shape", [((1, 4, 4, 8), (1, 1, 1, 8)), ((1, 16, 12, 4), (4,))]
+)
+@pytest.mark.parametrize("reverse_inputs", [False, True])
+def test_multiply(dtype, shape, constant_shape, reverse_inputs):
+    """Compare Multiply output with TVM."""
+    np.random.seed(0)
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    input2_zp = np.random.randint(data_min, data_max)
+    input2_sc = np.random.random() * 2
+    output_zp, output_sc = tei.get_conv2d_qnn_params(
+        dtype, input_zp, input_sc, input2_zp, input2_sc, 1, 1, shape[3]
+    )
+
+    model, params = _get_model(
+        shape,
+        constant_shape,
+        input_zp,
+        input_sc,
+        input2_zp,
+        input2_sc,
+        output_zp,
+        output_sc,
+        dtype,
+        reverse_inputs,
+    )
+    inputs = {"x": tvm.nd.array(np.random.randint(data_min, data_max + 1, size=shape, dtype=dtype))}
+    outputs = []
+    for npu in [False, True]:
+        mod = tei.make_module(model, params)
+        outputs.append(tei.build_and_run(mod, inputs, 1, params, npu=npu))
+
+    tei.verify(outputs, dtype, 1)
+
+
+@requires_ethosn
+def test_multiply_multiple_inputs_unsupported():
+    """Check that a multiply operator with two non-constant inputs is not offloaded."""
+    np.random.seed(0)
+
+    shape = (1, 4, 5, 6)
+    dtype = "int8"
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    input2_zp = np.random.randint(data_min, data_max)
+    input2_sc = np.random.random() * 2
+    output_zp, output_sc = tei.get_conv2d_qnn_params(
+        dtype, input_zp, input_sc, input2_zp, input2_sc, 1, 1, shape[3]
+    )
+
+    x = relay.var("x", shape=shape, dtype=dtype)
+    y = relay.var("y", shape=shape, dtype=dtype)
+    model = relay.qnn.op.mul(
+        x,
+        y,
+        relay.const(input_sc, "float32"),
+        relay.const(input_zp, "int32"),
+        relay.const(input2_sc, "float32"),
+        relay.const(input2_zp, "int32"),
+        relay.const(output_sc, "float32"),
+        relay.const(output_zp, "int32"),
+    )
+
+    expected_host_ops = 1
+    npu_partitions = 0
+    for npu in [False, True]:
+        mod = tei.make_module(model, {})
+        tei.build(
+            mod,
+            {},
+            npu=npu,
+            expected_host_ops=expected_host_ops,
+            npu_partitions=npu_partitions,
+        )
+
+
+@requires_ethosn
+def test_multiply_unsupported_datatype():
+    """Check that a multiply operator with an unsupported data type is not offloaded."""
+    np.random.seed(0)
+
+    shape = (1, 4, 5, 6)
+    dtype = "int16"
+
+    iinfo = np.iinfo(dtype)
+    data_min = iinfo.min
+    data_max = iinfo.max
+    input_zp = np.random.randint(data_min, data_max)
+    input_sc = np.random.random() * 2
+    input2_zp = np.random.randint(data_min, data_max)
+    input2_sc = np.random.random() * 2
+    output_zp, output_sc = tei.get_conv2d_qnn_params(
+        dtype, input_zp, input_sc, input2_zp, input2_sc, 1, 1, shape[3]
+    )
+
+    x = relay.var("x", shape=shape, dtype=dtype)
+    y = relay.var("y", shape=shape, dtype=dtype)
+    model = relay.qnn.op.mul(
+        x,
+        y,
+        relay.const(input_sc, "float32"),
+        relay.const(input_zp, "int32"),
+        relay.const(input2_sc, "float32"),
+        relay.const(input2_zp, "int32"),
+        relay.const(output_sc, "float32"),
+        relay.const(output_zp, "int32"),
+    )
+
+    expected_host_ops = 1
+    npu_partitions = 0
+    for npu in [False, True]:
+        mod = tei.make_module(model, {})
+        tei.build(
+            mod,
+            {},
+            npu=npu,
+            expected_host_ops=expected_host_ops,
+            npu_partitions=npu_partitions,
+        )
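Zooming out, the whole change rests on one piece of affine-quantization arithmetic: the int32 accumulator of the 1x1 depthwise convolution over zero-point-corrected values is exactly the product that `qnn.mul` dequantizes, so requantizing either form with matching scales yields the same output. A numpy check of that identity (illustrative; this is standard quantization semantics rather than the Ethos-N implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
c = 8
q_x = rng.integers(0, 256, size=(1, 4, 4, c)).astype(np.int32)  # uint8 input
q_k = rng.integers(0, 256, size=c).astype(np.int32)             # per-channel constant
z_x, s_x = 3, 0.02   # input zero point / scale (made up)
z_k, s_k = 5, 0.03   # constant zero point / scale (made up)
z_o, s_o = 7, 0.05   # output zero point / scale (made up)

# Real-valued product implied by qnn.mul ...
real = (q_x - z_x) * s_x * ((q_k - z_k) * s_k)

# ... equals the depthwise int32 accumulator scaled by s_x * s_k.
acc = (q_x - z_x) * (q_k - z_k)  # what the 1x1 depthwise conv accumulates
np.testing.assert_allclose(real, acc * (s_x * s_k))

# Requantizing either form to the output parameters gives the same result.
q_out = np.clip(np.round(real / s_o) + z_o, 0, 255).astype(np.uint8)
```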