diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 5354186167a..ac330d4b015 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -56,6 +56,7 @@ set(_cortex_m_kernels__srcs ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp new file mode 100644 index 00000000000..ad14af98865 --- /dev/null +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -0,0 +1,236 @@ +/* + * Copyright 2025 Arm Limited and/or its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "cortex_m_ops_common.h" + +extern "C" { +#include "arm_nnfunctions.h" +} + +namespace cortex_m { +namespace native { + +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; + +namespace { +constexpr int64_t kConvDim = 4; + +bool validate_conv2d_arguments( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const Tensor& output, + const IntArrayRef& stride, + const IntArrayRef& padding, + const IntArrayRef& dilation, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts) { + if (input.dim() != kConvDim || weight.dim() != kConvDim || + output.dim() != kConvDim) { + ET_LOG(Error, "quantized_conv2d_out: tensors must be 4-D"); + context.fail(Error::InvalidArgument); + return false; + } + + // Check for channels_last dim_order (NHWC: 0, 2, 3, 1) + // Skip check if channels == 1, as dim_order is ambiguous in that case + constexpr executorch::aten::DimOrderType kChannelsLastDimOrder[] = { + 0, 2, 3, 1}; + executorch::aten::ArrayRef + channels_last_order(kChannelsLastDimOrder, 4); + + if (input.size(1) > 1 && input.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: input must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (output.size(1) > 1 && output.dim_order() != channels_last_order) { + ET_LOG( + Error, + "quantized_conv2d_out: output must have channels_last dim_order (NHWC)"); + context.fail(Error::InvalidArgument); + return false; + } + + if (input.scalar_type() != ScalarType::Char || + output.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: input and output must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (weight.scalar_type() != ScalarType::Char) { + ET_LOG(Error, "quantized_conv2d_out: weight must be int8"); + context.fail(Error::InvalidArgument); + return false; + } + + if (bias.has_value() && bias.value().scalar_type() != ScalarType::Int) { + ET_LOG(Error, "quantized_conv2d_out: bias must be int32 if provided"); + context.fail(Error::InvalidArgument); + return false; + } + + if (stride.size() != 2 || padding.size() != 2 || dilation.size() != 2) { + ET_LOG( + Error, + "quantized_conv2d_out: stride, padding, and dilation must have length 2"); + context.fail(Error::InvalidArgument); + return false; + } + + const int64_t out_channels = output.size(1); + if (requantize_multipliers.size(0) != out_channels || + requantize_shifts.size(0) != out_channels) { + ET_LOG( + Error, + "quantized_conv2d_out: per-channel params must match output channels (%zd)", + out_channels); + context.fail(Error::InvalidArgument); + return false; + } + + return true; +} +} // namespace + +Tensor& quantized_conv2d_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& weight, + const torch::executor::optional& bias, + const IntArrayRef stride, + const IntArrayRef padding, + const IntArrayRef dilation, + const int64_t input_offset, + const int64_t output_offset, + const Tensor& requantize_multipliers, + const Tensor& requantize_shifts, + const int64_t activation_min, + const int64_t activation_max, + Tensor& out) { + if (!validate_conv2d_arguments( + context, + input, + weight, + bias, + out, + stride, + padding, + dilation, + requantize_multipliers, + requantize_shifts)) { + return out; + } + + const int32_t batch = static_cast(input.size(0)); + const int32_t input_channels = static_cast(input.size(1)); + const int32_t input_height = static_cast(input.size(2)); + const int32_t input_width = static_cast(input.size(3)); + + const int32_t kernel_output_channels = static_cast(weight.size(0)); + const int32_t kernel_height = static_cast(weight.size(1)); + const int32_t kernel_width = static_cast(weight.size(2)); + const int32_t kernel_input_channels = static_cast(weight.size(3)); + + const int32_t output_channels = static_cast(out.size(1)); + const int32_t output_height = static_cast(out.size(2)); + const int32_t output_width = static_cast(out.size(3)); + + const int32_t input_offset_val = static_cast(input_offset); + const int32_t output_offset_val = static_cast(output_offset); + const int32_t activation_min_val = static_cast(activation_min); + const int32_t activation_max_val = static_cast(activation_max); + + const cmsis_nn_dims input_dims{ + batch, input_height, input_width, input_channels}; + const cmsis_nn_dims filter_dims{ + kernel_output_channels, + kernel_height, + kernel_width, + kernel_input_channels}; + const cmsis_nn_dims output_dims{ + batch, output_height, output_width, output_channels}; + const cmsis_nn_dims bias_dims{1, 1, 1, output_channels}; + const cmsis_nn_dims upscale_dims{1, 1, 1, 1}; + + cmsis_nn_conv_params conv_params; + conv_params.input_offset = input_offset_val; + conv_params.output_offset = output_offset_val; + conv_params.stride.h = static_cast(stride[0]); + conv_params.stride.w = static_cast(stride[1]); + conv_params.padding.h = static_cast(padding[0]); + conv_params.padding.w = static_cast(padding[1]); + conv_params.dilation.h = static_cast(dilation[0]); + conv_params.dilation.w = static_cast(dilation[1]); + conv_params.activation.min = activation_min_val; + conv_params.activation.max = activation_max_val; + + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = requantize_multipliers.data_ptr(); + quant_params.shift = requantize_shifts.data_ptr(); + + const int8_t* input_data = input.const_data_ptr(); + const int8_t* weight_data = weight.const_data_ptr(); + int8_t* output_data = out.mutable_data_ptr(); + const int32_t* bias_data = + bias.has_value() ? bias.value().const_data_ptr() : nullptr; + + cmsis_nn_context cmsis_context; + cmsis_context.buf = nullptr; + cmsis_context.size = 0; + + const size_t buffer_bytes = static_cast( + arm_convolve_s8_get_buffer_size(&input_dims, &filter_dims)); + if (buffer_bytes > 0) { + auto buffer_or_error = + context.allocate_temp(buffer_bytes, alignof(int16_t)); + if (!buffer_or_error.ok()) { + if (buffer_or_error.error() != Error::NotFound) { + ET_LOG( + Error, + "quantized_conv2d_out: failed to allocate scratch buffer (%d)", + static_cast(buffer_or_error.error())); + context.fail(buffer_or_error.error()); + return out; + } + } else { + cmsis_context.buf = buffer_or_error.get(); + cmsis_context.size = buffer_bytes; + } + } + + const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( + &cmsis_context, + &conv_params, + &quant_params, + &input_dims, + input_data, + &filter_dims, + weight_data, + &bias_dims, + bias_data, + &output_dims, + output_data); + + if (status != ARM_CMSIS_NN_SUCCESS) { + ET_LOG( + Error, + "quantized_conv2d_out: arm_convolve_s8 failed with status %d", + status); + context.fail(Error::Internal); + } + + return out; +} + +} // namespace native +} // namespace cortex_m diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 8ad8f2a68e7..fe175ca9783 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -6,8 +6,10 @@ # LICENSE file in the root directory of this source tree. from math import prod +from typing import Sequence import torch +import torch.nn.functional as F from executorch.backends.cortex_m.passes.passes_utils import ( requantize_cmsis, SHIFT_INT8, @@ -408,3 +410,163 @@ def transpose_meta(input: torch.Tensor, perm) -> torch.Tensor: @impl(lib, "transpose", "CompositeExplicitAutograd") def transpose_impl(input: torch.Tensor, perm) -> torch.Tensor: return input.permute(tuple(perm)).contiguous() + + +# =================================================================== +# QUANTIZED CONV2D OPERATION DEFINITION +# =================================================================== + +lib.define( + "quantized_conv2d(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max" + ") -> Tensor" +) + + +lib.define( + "quantized_conv2d.out(" + "Tensor input, " + "Tensor weight, " + "Tensor? bias, " + "int[] stride, " + "int[] padding, " + "int[] dilation, " + "int input_offset, " + "int output_offset, " + "Tensor requantize_multipliers, " + "Tensor requantize_shifts, " + "int activation_min, " + "int activation_max, " + "*, Tensor(a!) out" + ") -> Tensor(a!)" +) + + +def _compute_conv2d_output_shape( + input_shape: torch.Size, + weight_shape: torch.Size, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], +) -> torch.Size: + batch = input_shape[0] + in_height = input_shape[2] + in_width = input_shape[3] + # We store the weights in OHWI layout (out, kernel_h, kernel_w, in) + kernel_height = weight_shape[1] + kernel_width = weight_shape[2] + + stride_h, stride_w = stride + pad_h, pad_w = padding + dilation_h, dilation_w = dilation + + out_channels = weight_shape[0] + out_height = ( + in_height + 2 * pad_h - dilation_h * (kernel_height - 1) - 1 + ) // stride_h + 1 + out_width = ( + in_width + 2 * pad_w - dilation_w * (kernel_width - 1) - 1 + ) // stride_w + 1 + return torch.Size([batch, out_channels, out_height, out_width]) + + +@register_fake("cortex_m::quantized_conv2d") +def quantized_conv2d_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + stride_vals = list(stride) + padding_vals = list(padding) + dilation_vals = list(dilation) + output_shape = _compute_conv2d_output_shape( + input.shape, weight.shape, stride_vals, padding_vals, dilation_vals + ) + return torch.empty( + output_shape, + dtype=torch.int8, + device=input.device, + memory_format=torch.channels_last, + ) + + +@impl(lib, "quantized_conv2d", "CompositeExplicitAutograd") +def quantized_conv2d_impl( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + input_offset: int, + output_offset: int, + requantize_multipliers: torch.Tensor, + requantize_shifts: torch.Tensor, + activation_min: int, + activation_max: int, +) -> torch.Tensor: + if input.dim() != 4 or weight.dim() != 4: + raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") + # Convert to int32 for accumulation and apply offsets + input_int32 = input.to(torch.int32) + int(input_offset) + weight_int32 = weight.to(torch.int32) + + if bias is None: + bias_int32 = torch.zeros( + weight.shape[0], dtype=torch.int32, device=input.device + ) + else: + bias_int32 = bias.to(torch.int32) + + input_channels = input.shape[1] + kernel_input_channels = weight.shape[3] + groups = input_channels // kernel_input_channels + + # Convert weights back to OIHW layout expected by torch.nn.functional.conv2d + weight_oi_hw = weight_int32.permute(0, 3, 1, 2).contiguous() + + conv_acc = F.conv2d( + input_int32, + weight_oi_hw, + bias_int32, + stride=tuple(stride), + padding=tuple(padding), + dilation=tuple(dilation), + groups=groups, + ) + + result_channels = [] + for output_channel_i in range(conv_acc.shape[1]): + result_channel = requantize_cmsis( + conv_acc[:, output_channel_i, :, :], + int(requantize_multipliers[output_channel_i]), + int(requantize_shifts[output_channel_i]), + ) + result_channels.append(result_channel) + + result = torch.stack(result_channels, dim=1) + + result += output_offset + result = torch.clamp(result, activation_min, activation_max) + + return result.to(torch.int8) diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index 30365e730da..0b0b2f5c715 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -52,3 +52,9 @@ kernels: - arg_meta: null kernel_name: cortex_m::transpose_out + +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + variants: function + kernels: + - arg_meta: null + kernel_name: cortex_m::quantized_conv2d_out diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 26456138cb2..d1bb580d871 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .quantized_linear_fusion_pass import QuantizedLinearFusionPass # noqa +from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py new file mode 100644 index 00000000000..c849b2949bf --- /dev/null +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx +from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot + +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + get_param_tensor, +) + +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind +from torch.fx.passes.infra.pass_manager import PassResult + + +class ConvertToCortexMPass(XNNPACKPass): + """ + Cortex-M backend pass for replacing supported quantized kernels with Cortex-M + accelerated kernels. + + Used for ops which require changes to input tensors which is not supported + by call_operator. + """ + + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + """ + Computes the precomputed kernel sum term (bias optional) + a * sum_j(wij + b) + ci + + for i = (1, ..., n), where j indexes the input activations. + """ + weights_transposed = weights.T + weights_int32 = weights_transposed.to(torch.int32) + offset_weights = weights_int32 + weight_offset + kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) + kernel_sum_offset = kernel_sum * input_offset + + if bias is not None: + kernel_sum_offset += bias + + return kernel_sum_offset + + def _get_linear_replacement(self, node): + """ + Let + - yi be the output activations (y1, ... yn) + - xj be the input activations (x1, ... xm) + - wij be the weights (w11, ... wnm) + - a be the input offset + - b be the weight offset + - ci be the bias + + Then the linear operation can be written as: + yi = sum_j((xj + a) * (wij + b)) + ci + = sum_j(xj*wij + xj*b + a*wij + a*b) + ci + = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) + = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum + + where kernel_sum is precomputed aot. + """ + input_scale = node.meta["input_qparams"][0].scale + input_zp = node.meta["input_qparams"][0].zp + weight_scale = node.meta["input_qparams"][1].scale + weight_zp = node.meta["input_qparams"][1].zp + output_scale = node.meta["output_qparams"][0].scale + output_zp = node.meta["output_qparams"][0].zp + output_min = node.meta["output_qparams"][0].qmin + output_max = node.meta["output_qparams"][0].qmax + + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + (input_scale * weight_scale) / output_scale + ) + + # TODO: Add support for configuring the backend to support other extensions. + # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, + # so this should be optional. + weights = node.args[1] + weights_tensor = get_param_tensor(self.exported_program, weights) + bias_tensor = ( + get_param_tensor(self.exported_program, node.args[2]) + if len(node.args) > 2 + else None + ) + kernel_sum_tensor = self._compute_kernel_sum( + weights_tensor, bias_tensor, -input_zp, -weight_zp + ) + with node.graph.inserting_after(weights): + kernel_sum = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_kernel_sum", + InputKind.PARAMETER, + kernel_sum_tensor, + ) + + args = ( + node.args[0], + weights, + None, + kernel_sum, + -input_zp, + -weight_zp, + output_zp, + [quantized_multiplier], + [quantized_shift], + output_max, + output_min, + ) + + return exir_ops.edge.cortex_m.quantized_linear.default, args + + def _get_convolution_replacement(self, node) -> int: + ( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = node.args + + # Extract values + input_scale = node.meta["input_qparams"][0].scale + input_zero_point = node.meta["input_qparams"][0].zp + weight_scales = node.meta["input_qparams"][1].scale + if not isinstance(weight_scales, list): + weight_scales = [weight_scales] * weight.data.shape[0] + + output_scale = node.meta["output_qparams"][0].scale + output_zero_point = node.meta["output_qparams"][0].zp + + quantized_multipliers = [] + quantized_shifts = [] + for weight_scale in weight_scales: + quantized_multiplier, quantized_shift = quantize_multiplier_aot( + input_scale * weight_scale / output_scale + ) + quantized_multipliers.append(quantized_multiplier) + quantized_shifts.append(quantized_shift) + + # Permute the weight tensor to the OHWI layout expected by CMSIS-NN. + weight_tensor = get_param_tensor(self.exported_program, weight) + weight_permuted = weight_tensor.permute(0, 2, 3, 1).contiguous( + memory_format=torch.channels_last + ) + + with node.graph.inserting_after(weight): + weight_nhwc = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_weight_nhwc", + InputKind.PARAMETER, + weight_permuted, + ) + + new_args = ( + x, + weight_nhwc, + bias, + stride, + padding, + dilation, + -input_zero_point, + output_zero_point, + torch.tensor(quantized_multipliers, dtype=torch.int32), + torch.tensor(quantized_shifts, dtype=torch.int32), + -128, + 127, + ) + return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if ( + node.meta.get("input_qparams", {}) == {} + or node.meta.get("output_qparams", {}) == {} + ): + continue + + match node.target: + case exir_ops.edge.aten.linear.default: + op, args = self._get_linear_replacement(node) + case exir_ops.edge.aten.convolution.default: + op, args = self._get_convolution_replacement(node) + case _: + continue + + with graph_module.graph.inserting_before(node): + cortex_m_op = graph_module.graph.create_node( + "call_function", + target=op, + args=args, + kwargs={}, + ) + + node.replace_all_uses_with(cortex_m_op) + graph_module.graph.erase_node(node) + + modified = True + + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 2b880f5ed05..948a60121b4 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -4,30 +4,34 @@ # LICENSE file in the root directory of this source tree. +import inspect + from executorch.backends.arm._passes import ( FoldAndAnnotateQParamsPass, ScalarsToAttributePass, ) from executorch.backends.cortex_m.passes import ( - QuantizedLinearFusionPass, + ConvertToCortexMPass, QuantizedOpFusionPass, ReplaceQuantNodesPass, ) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) -from executorch.backends.xnnpack._passes import XNNPACKPassManager from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_manager import PassManager +from executorch.exir.program._program import _transform +from torch.export import ExportedProgram -class CortexMPassManager(XNNPACKPassManager): +class CortexMPassManager(PassManager): pass_list: list[ExportPass] = [ FoldAndAnnotateQParamsPass, ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, QuantizedOpFusionPass, - QuantizedLinearFusionPass, + ConvertToCortexMPass, ] pass_list_transform_for_annotation: list[ExportPass] = [ @@ -36,10 +40,29 @@ class CortexMPassManager(XNNPACKPassManager): ] def __init__(self, exported_program, passes=None): - super().__init__(exported_program, passes or self.pass_list) + self.exported_program = exported_program + if passes is not None: + self.passes = passes + else: + self.passes = self.pass_list def transform_for_annotation(self, model): passes = self.pass_list_transform_for_annotation for p in passes: model = p().call(model).graph_module return model + + def transform(self) -> ExportedProgram: + ep = self.exported_program + for pass_ in self.passes: + signature = inspect.signature(pass_.__init__) + if "exported_program" in signature.parameters: + transform_pass = pass_(ep) + elif issubclass(pass_, ExportPass): + transform_pass = pass_() + else: + raise RuntimeError( + f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}" + ) + ep = _transform(ep, transform_pass) + return ep diff --git a/backends/cortex_m/passes/quantized_linear_fusion_pass.py b/backends/cortex_m/passes/quantized_linear_fusion_pass.py deleted file mode 100644 index f921f5ce621..00000000000 --- a/backends/cortex_m/passes/quantized_linear_fusion_pass.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import executorch.backends.cortex_m.ops.operators # noqa - -import torch -import torch.fx -from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot - -from executorch.backends.transforms.utils import ( - create_constant_placeholder, - get_param_tensor, -) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.exir.dialects._ops import ops as exir_ops -from torch.export.graph_signature import InputKind -from torch.fx.passes.infra.pass_manager import PassResult - - -class QuantizedLinearFusionPass(XNNPACKPass): - """ - Cortex-M backend pass that fuses quantized linear-like patterns. - Fuses: dequantize -> [linear/addmm/fc_ops] -> quantize - Into: cortex_m.quantized_linear.default with direct parameters. - - Note that the optimzed implementation makes use of the following rewrite: - - Let - - yi be the output activations (y1, ... yn) - - xj be the input activations (x1, ... xm) - - wij be the weights (w11, ... wnm) - - a be the input offset - - b be the weight offset - - ci be the bias - - Then the linear operation can be written as: - yi = sum_j((xj + a) * (wij + b)) + ci - = sum_j(xj*wij + xj*b + a*wij + a*b) + ci - = sum_j(xj*wij) + sum_j(xj)*b + (a * sum_j(wij + b) + ci) - = sum_j(xj*wij) + sum_j(xj)*b + kernel_sum - - where kernel_sum is precomputed aot. - """ - - def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): - """ - Computes the precomputed kernel sum term (bias optional) - a * sum_j(wij + b) + ci - - as defined above, for i = (1, ..., n) where j indexes the input activations. - """ - weights_transposed = weights.T - weights_int32 = weights_transposed.to(torch.int32) - offset_weights = weights_int32 + weight_offset - kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) - kernel_sum_offset = kernel_sum * input_offset - - if bias is not None: - kernel_sum_offset += bias - - return kernel_sum_offset - - def _get_linear_replacement(self, args, meta, node): - input_scale = meta["input_qparams"][0].scale - input_zp = meta["input_qparams"][0].zp - weight_scale = meta["input_qparams"][1].scale - weight_zp = meta["input_qparams"][1].zp - output_scale = meta["output_qparams"][0].scale - output_zp = meta["output_qparams"][0].zp - output_min = meta["output_qparams"][0].qmin - output_max = meta["output_qparams"][0].qmax - - quantized_multiplier, quantized_shift = quantize_multiplier_aot( - (input_scale * weight_scale) / output_scale - ) - - # TODO: Add support for configuring the backend to support other extensions. - # Kernel sum is only used in the CMSIS-NN implementation for the MVE extension, - # so this should be optional. - weights = args[1] - weights_tensor = get_param_tensor(self.exported_program, weights) - bias_tensor = ( - get_param_tensor(self.exported_program, args[2]) if len(args) > 2 else None - ) - kernel_sum_tensor = self._compute_kernel_sum( - weights_tensor, bias_tensor, -input_zp, -weight_zp - ) - with node.graph.inserting_after(weights): - kernel_sum = create_constant_placeholder( - self.exported_program, - node.graph, - node.name + "_kernel_sum", - InputKind.PARAMETER, - kernel_sum_tensor, - ) - - args = ( - args[0], - weights, - None, - kernel_sum, - -input_zp, - -weight_zp, - output_zp, - [quantized_multiplier], - [quantized_shift], - output_max, - output_min, - ) - - return args - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - modified = False - for node in graph_module.graph.nodes: - if node.op != "call_function": - continue - if node.target != exir_ops.edge.aten.linear.default: - continue - if ( - node.meta.get("input_qparams", {}) == {} - or node.meta.get("output_qparams", {}) == {} - ): - continue - - args = self._get_linear_replacement(node.args, node.meta, node) - with graph_module.graph.inserting_before(node): - cortex_m_linear = graph_module.graph.create_node( - "call_function", - target=exir_ops.edge.cortex_m.quantized_linear.default, - args=args, - kwargs={}, - ) - - node.replace_all_uses_with(cortex_m_linear) - graph_module.graph.erase_node(node) - - modified = True - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, modified) diff --git a/backends/cortex_m/quantizer/operator_configs.py b/backends/cortex_m/quantizer/operator_configs.py index 2936129819a..c6b15fb9a78 100644 --- a/backends/cortex_m/quantizer/operator_configs.py +++ b/backends/cortex_m/quantizer/operator_configs.py @@ -10,6 +10,7 @@ import torch from executorch.backends.cortex_m.quantizer.quantization_configs import ( + INT8_PER_CHANNEL_CONFIG, INT8_PER_TENSOR_CONFIG, ) from torchao.quantization.pt2e.quantizer import OperatorConfig @@ -25,6 +26,12 @@ [torch.ops.aten.linear.default, torch.ops.aten.relu.default], ] +CONV_OP_PATTERNS = [ + [torch.ops.aten.conv1d.default], + [torch.ops.aten.conv2d.default], + [torch.ops.aten.conv3d.default], +] + # ----------------- OPERATOR CONFIG PRESETS ----------------- INT8_BINARY_OPS_OPERATOR_CONFIG = OperatorConfig( INT8_PER_TENSOR_CONFIG, BINARY_OP_PATTERNS @@ -34,3 +41,8 @@ INT8_PER_TENSOR_CONFIG, LINEAR_OP_PATTERNS, ) + +INT8_CONV_OPERATOR_CONFIG = OperatorConfig( + INT8_PER_CHANNEL_CONFIG, + CONV_OP_PATTERNS, +) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 7f43a89daad..c6600241b6d 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -5,7 +5,11 @@ import torch -from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e import ( + HistogramObserver, + MinMaxObserver, + PerChannelMinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationConfig, @@ -21,8 +25,9 @@ INT8_WEIGHT_PER_CHANNEL_QSPEC = QuantizationSpec( dtype=torch.int8, - observer_or_fake_quant_ctr=MinMaxObserver, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, qscheme=torch.per_channel_symmetric, + ch_axis=0, ) INT8_ACTIVATION_PER_TENSOR_QSPEC = QuantizationSpec( @@ -33,8 +38,9 @@ INT8_ACTIVATION_PER_CHANNEL_QSPEC = QuantizationSpec( dtype=torch.int8, - observer_or_fake_quant_ctr=HistogramObserver, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver, qscheme=torch.per_channel_affine, + ch_axis=0, ) @@ -61,7 +67,18 @@ def _get_int32_bias_qspec(node): dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=torch.per_tensor_symmetric, + ) + + +def _get_int32_per_channel_bias_qspec(node): + return DerivedQuantizationSpec( + derived_from=[(node.args[0], node), (node.args[1], node)], # type: ignore[list-item] + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=torch.per_channel_symmetric, + ch_axis=0, ) @@ -75,8 +92,8 @@ def _get_int32_bias_qspec(node): INT8_PER_CHANNEL_CONFIG = QuantizationConfig( - INT8_ACTIVATION_PER_CHANNEL_QSPEC, - INT8_ACTIVATION_PER_CHANNEL_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, + INT8_ACTIVATION_PER_TENSOR_QSPEC, INT8_WEIGHT_PER_CHANNEL_QSPEC, - _get_int32_bias_qspec, + _get_int32_per_channel_bias_qspec, ) diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 1f9b06c27ec..8bfc32049ed 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -12,7 +12,9 @@ from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.operator_configs import ( BINARY_OP_PATTERNS, + CONV_OP_PATTERNS, INT8_BINARY_OPS_OPERATOR_CONFIG, + INT8_CONV_OPERATOR_CONFIG, INT8_LINEAR_OPERATOR_CONFIG, ) from executorch.backends.cortex_m.quantizer.quantization_configs import ( @@ -47,12 +49,30 @@ def broadcasting_filter(self, node: Optional[Node]) -> bool: return False + def nchw_filter(self, node: Optional[Node]) -> bool: + """ + Filter function to exclude nodes that use NCHW memory format. + """ + if node is None: + return False + if [node.target] not in CONV_OP_PATTERNS: + return False + + tensor = get_first_fake_tensor(node) + if tensor is None: + return False + + return not tensor.is_contiguous(memory_format=torch.channels_last) + def __init__(self) -> None: quantizers: List[Quantizer] = [ OperatorConfigQuantizer( INT8_BINARY_OPS_OPERATOR_CONFIG, filter_fn=self.broadcasting_filter ), OperatorConfigQuantizer(INT8_LINEAR_OPERATOR_CONFIG), + OperatorConfigQuantizer( + INT8_CONV_OPERATOR_CONFIG, filter_fn=self.nchw_filter + ), InputQuantizer(INT8_PER_TENSOR_CONFIG), OutputQuantizer(INT8_PER_TENSOR_CONFIG), SharedQspecQuantizer(), diff --git a/backends/cortex_m/test/ops/test_conv.py b/backends/cortex_m/test/ops/test_conv.py new file mode 100644 index 00000000000..c6bb4815dca --- /dev/null +++ b/backends/cortex_m/test/ops/test_conv.py @@ -0,0 +1,247 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import ( + CortexMTester, + McuTestCase, + ramp_tensor, +) + + +class CortexMConv1D(torch.nn.Module): + ops_before_transforms = {} + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv1d(*args, **kwargs, bias=False) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2D(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(1.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2DBias(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(*args, **kwargs, bias=True) + + def forward(self, x): + + return self.conv(x) + + +class CortexMConv3D(torch.nn.Module): + ops_before_transforms = {} + + ops_after_transforms = {} + + def __init__(self, *args, **kwargs): + super().__init__() + self.conv = torch.nn.Conv3d(*args, **kwargs, bias=False) + self.conv.weight.data.fill_(2.0) + + def forward(self, x): + return self.conv(x) + + +class CortexMConv2Dx3(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 3, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1, bias=False) + self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1, bias=False) + self.conv3 = torch.nn.Conv2d(16, 8, 3, padding=1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + + +class CortexMConv2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 1, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + } + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(4, 8, 3, padding=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + +# in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode +test_cases = { + "conv2d": McuTestCase( + model=CortexMConv2D(2, 4, 3), + example_inputs=( + ramp_tensor(1, 5, (1, 2, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_stride": McuTestCase( + model=CortexMConv2D(3, 4, (1, 2), stride=2), + example_inputs=( + ramp_tensor(-100, 10, (3, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_padding": McuTestCase( + model=CortexMConv2D(3, 2, 3, padding=(4, 1)), + example_inputs=( + ramp_tensor(0, 1, (2, 3, 5, 5)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_dilation": McuTestCase( + model=CortexMConv2D(1, 4, 3, dilation=(2, 2)), + example_inputs=( + ramp_tensor(0, 10, (3, 1, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_groups": McuTestCase( + model=CortexMConv2D(4, 4, 1, groups=2), + example_inputs=( + ramp_tensor(0, 10, (1, 4, 1, 1)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_1": McuTestCase( + model=CortexMConv2DBias(5, 1, 1), + example_inputs=( + ramp_tensor(0, 10, (2, 5, 3, 3)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_bias_ch_out_4": McuTestCase( + model=CortexMConv2DBias(5, 4, (1, 2)), + example_inputs=( + ramp_tensor(-3, 3, (2, 5, 10, 10)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_nchw": McuTestCase( + model=CortexMConv2D(5, 5, 1), + example_inputs=(ramp_tensor(0, 10, (1, 5, 8, 8)),), + ), + "conv1d": McuTestCase( + model=CortexMConv1D(1, 1, 1), + example_inputs=(ramp_tensor(0, 10, (1, 3, 2)),), + ), + "conv3d": McuTestCase( + model=CortexMConv3D(1, 1, 1), + example_inputs=( + ramp_tensor(-1000, 1000, (2, 1, 3, 3, 3)).to( + memory_format=torch.channels_last_3d + ), + ), + ), + "conv2d_x3": McuTestCase( + model=CortexMConv2Dx3(), + example_inputs=( + ramp_tensor(0, 10, (1, 3, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "conv2d_relu": McuTestCase( + model=CortexMConv2DReLU(), + example_inputs=( + ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), +} + + +xfails_dialect = { + "conv2d_dilation": "NotImplementedError: 'slow_conv_dilated<>' not implemented for 'Int'", + "conv1d": "Currently not supported.", + "conv2d_nchw": "Currently not supported.", + "conv3d": "Currently not supported.", + "conv2d_relu": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_dialect) +def test_dialect_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, + test_case.model.ops_after_transforms, + qtol=1, + ) + + +xfails_implementation = { + "conv1d": "Currently not supported.", + "conv2d_nchw": "Currently not supported.", + "conv3d": "Currently not supported.", + "conv2d_relu": "Currently not supported.", +} + + +@parametrize("test_case", test_cases, xfails=xfails_implementation) +def test_implementation_conv2d(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation(qtol=1)