From 59fdd7278783197ce2fa7241443403aeda0d6b23 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 27 Jan 2026 10:14:31 -0800 Subject: [PATCH] Revert "[cortex_m] Fix linear weight layout: transpose in AOT pass, align meta/ref impl (#16782)" This reverts commit 06f10b9d71f44bc5d30fc4ab456a332d0b5bddd4. --- backends/cortex_m/ops/operators.py | 6 ++--- .../passes/convert_to_cortex_m_pass.py | 25 +++---------------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index b4a36bc7258..0bc9702d680 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -352,7 +352,7 @@ def quantized_linear_meta( activation_min, ) -> torch.Tensor: - shape = (*input.shape[:-1], weights.shape[1]) + shape = (*input.shape[:-1], weights.shape[0]) return torch.empty(shape, dtype=input.dtype, device=input.device) @@ -386,7 +386,7 @@ def quantized_linear_impl( input_reshaped = input_int32.reshape(new_shape) lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset - output = torch.mm(input_reshaped, weights_int32) + lhs_sum + kernel_sum + output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum output_shape = (*input.shape[:-1], output.shape[-1]) output_reshaped = output.reshape(output_shape) else: @@ -396,7 +396,7 @@ def quantized_linear_impl( new_shape = (prod(input.shape[:-1]), input.shape[-1]) input_reshaped = input_int32.reshape(new_shape) - output = torch.mm(input_reshaped, weights_int32) + output = torch.mm(input_reshaped, weights_int32.T) if bias is not None: output = output + bias output_shape = (*input.shape[:-1], output.shape[-1]) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 95c10369009..8da0e720036 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -33,19 +33,14 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. """ - def _compute_kernel_sum( - self, weights_transposed, bias, input_offset, weight_offset - ): + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) a * sum_j(wij + b) + ci for i = (1, ..., n), where j indexes the input activations. - - Args: - weights_transposed: Weights already in [in_features, out_features] format """ - # No transpose needed - weights already transposed by caller + weights_transposed = weights.T weights_int32 = weights_transposed.to(torch.int32) offset_weights = weights_int32 + weight_offset kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) @@ -115,12 +110,8 @@ def _get_linear_replacement(self, node): if len(node.args) > 2 else None ) - # Transpose weights once from PyTorch format [out_features, in_features] - # to CMSIS-NN format [in_features, out_features] - weights_transposed = weights_tensor.T.contiguous() - # Pass already-transposed weights to kernel_sum computation kernel_sum_tensor = self._compute_kernel_sum( - weights_transposed, bias_tensor, -input_zp, -weight_zp + weights_tensor, bias_tensor, -input_zp, -weight_zp ) with node.graph.inserting_after(weights): kernel_sum = create_constant_placeholder( @@ -131,17 +122,9 @@ def _get_linear_replacement(self, node): kernel_sum_tensor, ) - weights_transposed_node = create_constant_placeholder( - self.exported_program, - node.graph, - node.name + "_weights_transposed", - InputKind.PARAMETER, - weights_transposed, - ) - args = ( node.args[0], - weights_transposed_node, + weights, None, kernel_sum, -input_zp,