diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 0bc9702d680..b4a36bc7258 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -352,7 +352,7 @@ def quantized_linear_meta( activation_min, ) -> torch.Tensor: - shape = (*input.shape[:-1], weights.shape[0]) + shape = (*input.shape[:-1], weights.shape[1]) return torch.empty(shape, dtype=input.dtype, device=input.device) @@ -386,7 +386,7 @@ def quantized_linear_impl( input_reshaped = input_int32.reshape(new_shape) lhs_sum = torch.sum(input_reshaped, dim=-1, keepdim=True) * filter_offset - output = torch.mm(input_reshaped, weights_int32.T) + lhs_sum + kernel_sum + output = torch.mm(input_reshaped, weights_int32) + lhs_sum + kernel_sum output_shape = (*input.shape[:-1], output.shape[-1]) output_reshaped = output.reshape(output_shape) else: @@ -396,7 +396,7 @@ def quantized_linear_impl( new_shape = (prod(input.shape[:-1]), input.shape[-1]) input_reshaped = input_int32.reshape(new_shape) - output = torch.mm(input_reshaped, weights_int32.T) + output = torch.mm(input_reshaped, weights_int32) if bias is not None: output = output + bias output_shape = (*input.shape[:-1], output.shape[-1]) diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 8da0e720036..95c10369009 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -33,14 +33,19 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. """ - def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): + def _compute_kernel_sum( + self, weights_transposed, bias, input_offset, weight_offset + ): """ Computes the precomputed kernel sum term (bias optional) a * sum_j(wij + b) + ci for i = (1, ..., n), where j indexes the input activations. + + Args: + weights_transposed: Weights already in [in_features, out_features] format """ - weights_transposed = weights.T + # No transpose needed - weights already transposed by caller weights_int32 = weights_transposed.to(torch.int32) offset_weights = weights_int32 + weight_offset kernel_sum = torch.sum(offset_weights, dim=0, keepdim=True, dtype=torch.int32) @@ -110,8 +115,12 @@ def _get_linear_replacement(self, node): if len(node.args) > 2 else None ) + # Transpose weights once from PyTorch format [out_features, in_features] + # to CMSIS-NN format [in_features, out_features] + weights_transposed = weights_tensor.T.contiguous() + # Pass already-transposed weights to kernel_sum computation kernel_sum_tensor = self._compute_kernel_sum( - weights_tensor, bias_tensor, -input_zp, -weight_zp + weights_transposed, bias_tensor, -input_zp, -weight_zp ) with node.graph.inserting_after(weights): kernel_sum = create_constant_placeholder( @@ -122,9 +131,17 @@ def _get_linear_replacement(self, node): kernel_sum_tensor, ) + weights_transposed_node = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_weights_transposed", + InputKind.PARAMETER, + weights_transposed, + ) + args = ( node.args[0], - weights, + weights_transposed_node, None, kernel_sum, -input_zp,