diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 2fcdaf362a22..9bc6efdad00f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -424,7 +424,8 @@ def is_aarch64_arm(): @qnn_conv2d_legalize.register("arm_cpu") def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): - # ARM prefers the dtypes to be same. + target = tvm.target.Target.current(allow_none=False) + has_asimd = is_aarch64_arm() or "+neon" in target.mattr is_depthwise = relay.op.strategy.is_depthwise_conv2d( types[0].shape, attrs["data_layout"], @@ -432,18 +433,23 @@ def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): attrs["kernel_layout"], attrs["groups"], ) - use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC" - if use_int8_on_arm or is_fast_int8_on_arm(): - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) - return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) + use_int8_on_arm = (not is_depthwise) and attrs["data_layout"] == "NHWC" + has_dotprod = is_fast_int8_on_arm() + other_options = use_int8_on_arm or has_dotprod + if has_asimd and not other_options: + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) + # ARM prefers the dtypes to be same. + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) @qnn_dense_legalize.register("arm_cpu") def _qnn_dense_legalize_arm_cpu(attrs, inputs, types): + target = tvm.target.Target.current(allow_none=False) + has_asimd = is_aarch64_arm() or "+neon" in target.mattr + if has_asimd and not is_fast_int8_on_arm(): + return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) # ARM prefers the dtypes to be same. - if is_fast_int8_on_arm(): - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) - return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense) + return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) ##########################