From d66ee7429d6eae2414083a4dce9667b0e5e42559 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Thu, 30 Sep 2021 17:01:41 +0000 Subject: [PATCH] [CMSIS-NN] Initial operator support for Add This patch aims to add initial support for the `Add` operator to CMSIS NN, which was actually similar enough to the `Mul` operator that it shares quite a bit of code - exciting times. --- python/tvm/relay/op/contrib/cmsisnn.py | 19 ++- .../backend/contrib/cmsisnn/relay_to_tir.cc | 130 ++++++++++++++---- .../{test_mul.py => test_binary_ops.py} | 27 ++-- 3 files changed, 134 insertions(+), 42 deletions(-) rename tests/python/contrib/test_cmsisnn/{test_mul.py => test_binary_ops.py} (87%) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index c28e97b0e9d3..db584fb2eb71 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -79,9 +79,9 @@ def check_quantized_softmax(extract): and dequantize_call.args[0].checked_type.dtype == "int8" ) - def mul_pattern(): - """Matcher for QNN multiplication""" - return is_op("qnn.mul")( + def binary_op_pattern(op): + """Matches QNN binary operation""" + return is_op(f"qnn.{op}")( wildcard(), wildcard(), is_constant(), @@ -92,7 +92,7 @@ def mul_pattern(): is_constant(), ) - def check_quantized_mul(extract): + def check_quantized_binary_op(extract): """Check if multiply is supported by CMSIS-NN.""" return ( extract.args[0].checked_type.dtype == "int8" @@ -101,5 +101,14 @@ def check_quantized_mul(extract): return [ ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax), - ("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul), + ( + "cmsisnn.quantized_mul", + binary_op_pattern("mul"), + check_quantized_binary_op, + ), + ( + "cmsisnn.quantized_add", + binary_op_pattern("add"), + check_quantized_binary_op, + ), ] diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index bcb171ca25f8..3c3346340f04 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -26,6 +26,7 @@ #include #include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" namespace tvm { namespace relay { @@ -39,11 +40,7 @@ class RelayToTIRVisitor : public MixedModeVisitor { tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; } private: - template - const T ArgumentToConstantValue(const Expr& arg) { - const ConstantNode* constant_node = arg.as(); - return static_cast(constant_node->data->data)[0]; - } + inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); } void CreatePrimFuncForExtern(Array func_signature, tvm::Array call_extern_args) { @@ -62,7 +59,7 @@ class RelayToTIRVisitor : public MixedModeVisitor { auto* quantize_call = expr.as(); auto* softmax_call = quantize_call->args[0].as(); auto* dequant_call = softmax_call->args[0].as(); - const float quant_scale = ArgumentToConstantValue(dequant_call->args[1]); + const float quant_scale = GetScalarFromConstant(dequant_call->args[1]); // assuming layout as NHWC auto shape = quantize_call->type_as()->shape; @@ -95,10 +92,15 @@ class RelayToTIRVisitor : public MixedModeVisitor { Array func_signature{in_var, out_var}; tvm::Array args = { - tir::StringImm("arm_softmax_s8"), in_var, - IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size), - IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift), - IntImm(DataType::Int(32), diff_min), out_var}; + tir::StringImm("arm_softmax_s8"), + in_var, + ToArg(num_rows), + ToArg(row_size), + ToArg(mult), + ToArg(shift), + ToArg(diff_min), + out_var, + }; CreatePrimFuncForExtern(func_signature, args); } @@ -106,12 +108,12 @@ class RelayToTIRVisitor : public MixedModeVisitor { void EmitMul(const Expr& expr) { auto* mul_call = expr.as(); - const float input_0_scale = ArgumentToConstantValue(mul_call->args[2]); - const int32_t input_0_zero_point = ArgumentToConstantValue(mul_call->args[3]); - const float input_1_scale = ArgumentToConstantValue(mul_call->args[4]); - const int32_t input_1_zero_point = ArgumentToConstantValue(mul_call->args[5]); - const float output_scale = ArgumentToConstantValue(mul_call->args[6]); - const int32_t output_zero_point = ArgumentToConstantValue(mul_call->args[7]); + const float input_0_scale = GetScalarFromConstant(mul_call->args[2]); + const int32_t input_0_zero_point = GetScalarFromConstant(mul_call->args[3]); + const float input_1_scale = GetScalarFromConstant(mul_call->args[4]); + const int32_t input_1_zero_point = GetScalarFromConstant(mul_call->args[5]); + const float output_scale = GetScalarFromConstant(mul_call->args[6]); + const int32_t output_zero_point = GetScalarFromConstant(mul_call->args[7]); double quantized_multiplier = static_cast(input_0_scale) * static_cast(input_1_scale) / @@ -132,14 +134,81 @@ class RelayToTIRVisitor : public MixedModeVisitor { tir::StringImm("arm_elementwise_mul_s8"), input_0, input_1, - IntImm(DataType::Int(32), -input_0_zero_point), - IntImm(DataType::Int(32), -input_1_zero_point), + ToArg(-input_0_zero_point), + ToArg(-input_1_zero_point), output, - IntImm(DataType::Int(32), output_zero_point), - IntImm(DataType::Int(32), output_multiplier), - IntImm(DataType::Int(32), output_shift), - IntImm(DataType::Int(32), std::numeric_limits::min()), - IntImm(DataType::Int(32), std::numeric_limits::max()), + ToArg(output_zero_point), + ToArg(output_multiplier), + ToArg(output_shift), + ToArg(std::numeric_limits::min()), + ToArg(std::numeric_limits::max()), + tensor_size, + }; + + CreatePrimFuncForExtern(func_signature, args); + } + + void EmitAdd(const Expr& expr) { + auto* add_call = expr.as(); + + const float input_0_scale = GetScalarFromConstant(add_call->args[2]); + const int32_t input_0_zero_point = GetScalarFromConstant(add_call->args[3]); + const float input_1_scale = GetScalarFromConstant(add_call->args[4]); + const int32_t input_1_zero_point = GetScalarFromConstant(add_call->args[5]); + const float output_scale = GetScalarFromConstant(add_call->args[6]); + const int32_t output_zero_point = GetScalarFromConstant(add_call->args[7]); + + const int32_t left_shift = 20; + const int32_t input_0_offset = -input_0_zero_point; + const int32_t input_1_offset = -input_1_zero_point; + + const float max_input_scale = std::max(input_0_scale, input_1_scale); + const double twice_max_input_scale = 2 * static_cast(max_input_scale); + const double scaled_input_0_scale = static_cast(input_0_scale) / twice_max_input_scale; + const double scaled_input_1_scale = static_cast(input_1_scale) / twice_max_input_scale; + const double scaled_output_scale = + twice_max_input_scale / ((1 << left_shift) * static_cast(output_scale)); + + auto input_0_mult_shift_pair = + tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_input_0_scale); + int32_t input_0_multiplier = std::get<0>(input_0_mult_shift_pair); + int32_t input_0_shift = std::get<1>(input_0_mult_shift_pair); + + auto input_1_mult_shift_pair = + tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_input_1_scale); + int32_t input_1_multiplier = std::get<0>(input_1_mult_shift_pair); + int32_t input_1_shift = std::get<1>(input_1_mult_shift_pair); + + auto output_mult_shift_pair = + tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_output_scale); + int32_t output_multiplier = std::get<0>(output_mult_shift_pair); + int32_t output_shift = std::get<1>(output_mult_shift_pair); + + PrimExpr tensor_size = add_call->type_as()->Size(); + + tir::Var input_0("input_0", DataType::Handle(8)); + tir::Var input_1("input_1", DataType::Handle(8)); + tir::Var output("output", DataType::Handle(8)); + + Array func_signature{input_0, input_1, output}; + + tvm::Array args = { + tir::StringImm("arm_elementwise_add_s8"), + input_0, + input_1, + ToArg(input_0_offset), + ToArg(input_0_multiplier), + ToArg(input_0_shift), + ToArg(input_1_offset), + ToArg(input_1_multiplier), + ToArg(input_1_shift), + ToArg(left_shift), + output, + ToArg(output_zero_point), + ToArg(output_multiplier), + ToArg(output_shift), + ToArg(std::numeric_limits::min()), + ToArg(std::numeric_limits::max()), tensor_size, }; @@ -153,11 +222,16 @@ class RelayToTIRVisitor : public MixedModeVisitor { } auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") { - EmitSoftMax(func->body); - } - if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") { - EmitMul(func->body); + if (comp_name.defined()) { + if (comp_name == "cmsisnn.quantized_softmax") { + EmitSoftMax(func->body); + } + if (comp_name == "cmsisnn.quantized_mul") { + EmitMul(func->body); + } + if (comp_name == "cmsisnn.quantized_add") { + EmitAdd(func->body); + } } } diff --git a/tests/python/contrib/test_cmsisnn/test_mul.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py similarity index 87% rename from tests/python/contrib/test_cmsisnn/test_mul.py rename to tests/python/contrib/test_cmsisnn/test_binary_ops.py index 88fbeb2dfcfe..72e47e50b878 100644 --- a/tests/python/contrib/test_cmsisnn/test_mul.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""CMSIS-NN integration tests: mul""" +"""CMSIS-NN integration tests: binary ops""" import sys @@ -35,6 +35,7 @@ def make_model( + op, shape, input_0_dtype, input_1_dtype, @@ -47,7 +48,7 @@ def make_model( ): """Create a Relay Function / network model""" - return relay.qnn.op.mul( + return op( relay.var("input_0", shape=shape, dtype=input_0_dtype), relay.var("input_1", shape=shape, dtype=input_1_dtype), relay.const(input_0_scale, "float32"), @@ -60,19 +61,17 @@ def make_model( @skip_if_no_reference_system +@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add]) @pytest.mark.parametrize( [ "input_0_scale", "input_0_zero_point", "input_1_scale", "input_1_zero_point", - "output_tolerance", ], - [[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]], + [[0.256, 33, 0.256, 33], [0.0128, -64, 0.0128, -64], [0.0128, -64, 0.256, 33]], ) -def test_mul_int8( - input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance -): +def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point): interface_api = "c" use_unpacked_api = True test_runner = AOT_CORSTONE300_RUNNER @@ -80,7 +79,14 @@ def test_mul_int8( dtype = "int8" shape = [1, 16, 16, 3] model = make_model( - shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point + op, + shape, + dtype, + dtype, + input_0_scale, + input_0_zero_point, + input_1_scale, + input_1_zero_point, ) orig_mod = make_module(model) @@ -115,7 +121,7 @@ def test_mul_int8( module=cmsisnn_mod, inputs=inputs, outputs=output_list, - output_tolerance=output_tolerance, + output_tolerance=1, ), test_runner, interface_api, @@ -123,13 +129,16 @@ def test_mul_int8( ) +@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add]) @pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]]) def test_invalid_parameters( + op, input_dtype, ): input_scale = 0.256 input_zero_point = 33 model = make_model( + op, [1, 16, 16, 3], input_dtype, input_dtype,