From 4484c2f2c156a6c730cefbbbc4950921a94ff068 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Thu, 22 Jan 2026 08:48:50 -0800 Subject: [PATCH] Title: [cortex_m] Fix linear weight layout: transpose in AOT pass, align meta/ref impl Summary: The linear path in ConvertToCortexMPass was not transposing weights unlike conv2d, causing inconsistency with the C++ runtime which expects weights in [in_features, out_features] format per CMSIS-NN. Changes: - convert_to_cortex_m_pass.py: Transpose linear weights [out, in] -> [in, out] - operators.py: Update meta to use weights.shape[1] for output dimension - operators.py: Remove .T from ref impl (weights pre-transposed by pass) - operators.py: Transpose once, pass to _compute_kernel_sum Fixes MV2 output shape mismatch: [1, 1280] -> [1, 1000]. Verified with MV2 on Corstone-300/E8 with CMSIS-NN kernels. This fix ensures the AOT-compiled .pte file has correctly shaped output tensors for any model using quantized_linear (MV2, ResNet, MV3, etc.). --- backends/cortex_m/ops/operators.py | 6 ++--- .../passes/convert_to_cortex_m_pass.py | 25 ++++++++++++++++--- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 21c8d4b2b3c..e2a9dbe883c 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -352,7 +352,7 @@ def quantized_linear_meta( activation_min, ) -> torch.Tensor: - shape = (*input.shape[:-1], weights.shape[0]) + shape = (*input.shape[:-1], weights.shape[1]) return torch.empty(shape, dtype=input.dtype, device=input.device) @@ -386,7 +386,7 @@ def quantized_linear_impl( input_reshaped = input_int32.reshape(new_shape) lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset - output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output = torch.mm(input_reshaped, weights_int32) + lhs_sum + kernel_sum output_shape = (*input.shape[:-1], output.shape[-1]) output_reshaped = 
output.reshape(output_shape) else: @@ -396,7 +396,7 @@ def quantized_linear_impl( new_shape = (prod(input.shape[:-1]), input.shape[-1]) input_reshaped = input_int32.reshape(new_shape) - output = torch.mm(input_reshaped, weights_int32.T) + output = torch.mm(input_reshaped, weights_int32) if bias is not None: output = output + bias output_shape = (*input.shape[:-1], output.shape[-1]) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index e2d3c48d2af..9e84ea10e26 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -33,14 +33,19 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. """ - def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + def _compute_kernel_sum( + self, weights_transposed, bias, input_offset, weight_offset + ): """ Computes the precomputed kernel sum term (bias optional) a * sum_j(wij + b) + ci for i = (1, ..., n), where j indexes the input activations. 
+ + Args: + weights_transposed: Weights already in [in_features, out_features] format """ - weights_transposed = weights.T + # No transpose needed - weights already transposed by caller weights_int32 = weights_transposed.to(torch.int32) offset_weights = weights_int32 + weight_offset kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) @@ -110,8 +115,12 @@ def _get_linear_replacement(self, node): if len(node.args) > 2 else None ) + # Transpose weights once from PyTorch format [out_features, in_features] + # to CMSIS-NN format [in_features, out_features] + weights_transposed = weights_tensor.T.contiguous() + # Pass already-transposed weights to kernel_sum computation kernel_sum_tensor = self._compute_kernel_sum( - weights_tensor, bias_tensor, -input_zp, -weight_zp + weights_transposed, bias_tensor, -input_zp, -weight_zp ) with node.graph.inserting_after(weights): kernel_sum = create_constant_placeholder( @@ -122,9 +131,17 @@ def _get_linear_replacement(self, node): kernel_sum_tensor, ) + weights_transposed_node = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_weights_transposed", + InputKind.PARAMETER, + weights_transposed, + ) + args = ( node.args[0], - weights, + weights_transposed_node, None, kernel_sum, -input_zp,