Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions backends/nxp/backend/ir/converter/conversion/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,22 @@ def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor
return tensor


def extend_1d_pads_to_2d(onnx_1d_pads: MutableSequence):
    """Extend an ONNX 'pads' attribute of a 1D kernel to 2D, in place.

    ONNX lays pads out as ``[*begins, *ends]``, so a '0' begin-pad is inserted
    and a '0' end-pad is appended for the newly introduced dimension.

    :param onnx_1d_pads: 1D pads ``[x1_begin, x1_end]``; may be None, in which
        case nothing is done.
    """
    if onnx_1d_pads is None:
        return
    # [wb, we] -> [wb, 0, we, 0]: the new dimension is not padded on either side.
    onnx_1d_pads.insert(1, 0)
    onnx_1d_pads.append(0)
def extend_1d_padding_to_2d(tflite_1d_padding: MutableSequence):
    """Extend the PyTorch 'padding' attribute of a 1D kernel to 2D, in place,
    by appending a single '0' for the newly introduced dimension.

    :param tflite_1d_padding: per-dimension padding amounts; may be None, in
        which case nothing is done.
    """
    if tflite_1d_padding is None:
        return
    # The dimension introduced by the 1D -> 2D conversion receives no padding.
    tflite_1d_padding.append(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my specific concern: we are padding with zeros, not the zero point, just like in the 2D case previously.

Copy link
Collaborator Author

@roman-janik-nxp roman-janik-nxp Aug 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value expresses the amount of padding applied to the input, not the padding value. So a zero here means that the tensor will not be padded on the H dimension, as this is a conversion from a 1D NWC tensor to a 2D NHWC tensor — the padding also needs to be extended/converted. The amount of padding for the W dim is kept.
The padding value is the zero-point added on L352 and L390 in convolution_converter.py.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks!



def extend_1d_strides_to_2d(onnx_1d_strides: MutableSequence):
    """Extend an ONNX 'strides' attribute of a 1D kernel to 2D, in place, by
    appending a '1' (the identity stride) for the newly introduced dimension.

    :param onnx_1d_strides: 1D strides; may be None, in which case nothing is done.
    """
    if onnx_1d_strides is None:
        return
    onnx_1d_strides.append(1)
def extend_1d_stride_to_2d(tflite_1d_stride: MutableSequence):
    """Extend the PyTorch 'stride' attribute of a 1D kernel to 2D, in place, by
    appending a '1' (the identity stride) for the newly introduced dimension.

    :param tflite_1d_stride: 1D stride; may be None, in which case nothing is done.
    """
    if tflite_1d_stride is None:
        return
    tflite_1d_stride.append(1)


def extend_1d_dilations_to_2d(onnx_1d_dilations: MutableSequence):
    """Extend an ONNX 'dilations' attribute of a 1D kernel to 2D, in place, by
    appending a '1' (the identity dilation) for the newly introduced dimension.

    :param onnx_1d_dilations: 1D dilations; may be None, in which case nothing is done.
    """
    if onnx_1d_dilations is None:
        return
    onnx_1d_dilations.append(1)


def extend_1d_kernel_shape_to_2d(onnx_1d_kernel_shape: MutableSequence):
    """Extend an ONNX 1D 'kernel_shape' attribute to 2D, in place, by appending
    a '1' (a size-1 kernel) for the newly introduced dimension.

    :param onnx_1d_kernel_shape: 1D kernel shape; may be None, in which case
        nothing is done.
    """
    if onnx_1d_kernel_shape is None:
        return
    onnx_1d_kernel_shape.append(1)
def extend_1d_dilation_to_2d(tflite_1d_dilation: MutableSequence):
    """Extend the PyTorch 'dilation' attribute of a 1D kernel to 2D, in place,
    by appending a '1' (the identity dilation) for the newly introduced dimension.

    :param tflite_1d_dilation: 1D dilation; may be None, in which case nothing is done.
    """
    if tflite_1d_dilation is None:
        return
    tflite_1d_dilation.append(1)


StridedOptions = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
from executorch.backends.nxp.backend.ir.converter.conversion import (
aten_translator,
common,
translator,
)
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
tf_lite_type_to_numpy,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
NodeConverter,
Target,
Expand All @@ -36,6 +40,7 @@
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
conv_2d_options,
depthwise_conv_2d_options,
reshape_options,
)
from torch.fx import Node
from torch.nn import Parameter
Expand Down Expand Up @@ -85,13 +90,15 @@ def _is_supported_on_target(
def _is_supported_in_IR(
node: Node, parameters_mapping: dict[str, Parameter]
) -> bool:
input_tensor_rank = len(node.meta["val"].shape)
dimensions = input_tensor_rank - 2
is_transposed = node.args[6]
output_padding = node.args[7]

if is_transposed:
return False

if output_padding != [0, 0]:
if output_padding != [0] * dimensions:
return False

if input_tensor_safe(node, 2) is None:
Expand All @@ -116,7 +123,107 @@ def _get_convolution_arguments(
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
conv_node.args
)
return stride, padding, dilation, transposed, out_padding, groups
return (
list(stride),
list(padding),
list(dilation),
transposed,
out_padding,
groups,
)

def _convert_1d_conv(
    self, t_op: tflite_model.Operator, conv_params: ConvParameters
) -> list[tflite_model.Operator]:
    """Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.

    TFLite doesn't support 1D convolution, but this behaviour can be represented using
    Reshape -> Conv2D -> Reshape.
    The first reshape introduces a 4th dimension with size 1. The second Reshape
    removes the temporary dimension.

    :param t_op: the operator being converted; its `tmp_inputs` and `tmp_outputs`
        are re-wired in place to the intermediate 4D tensors.
    :param conv_params: convolution attributes; its stride/padding/dilation
        sequences are extended to 2D in place.
    :return: the list of TFLite operators realising the 1D convolution
        (pre-reshapes + converted Conv2D ops + trailing reshape).
    """
    # -- Calculate the shapes for equivalent 2D convolution --
    # NOTE(review): `nhc_dimensions_to_nhwc` presumably inserts a size-1 spatial
    # dimension to turn the 3D shape into 4D NHWC — confirm in `translator`.
    conv_2d_input_shape = translator.nhc_dimensions_to_nhwc(
        t_op.tmp_inputs[0].shape.vector
    )
    conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc(
        t_op.tmp_inputs[1].shape.vector
    )
    conv_2d_output_shape = translator.nhc_dimensions_to_nhwc(
        t_op.tmp_outputs[0].shape.vector
    )

    # -- Generate tensors taking part in the conversion --
    reshape1_input = t_op.tmp_inputs[0]

    # 4D copy of the main input, produced by the first Reshape and fed to Conv2D.
    reshape1_output = self.builder.duplicate_tensor(
        reshape1_input, name_suffix="_4D_"
    )
    reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape)

    # 4D copy of the output, produced by Conv2D and consumed by the final Reshape.
    reshape2_input = self.builder.duplicate_tensor(
        t_op.tmp_outputs[0], name_suffix="_4D_"
    )
    reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape)

    reshape2_output = t_op.tmp_outputs[0]

    # Operators that must run before the Conv2D itself.
    pre_reshapes: list[tflite_model.Operator] = []

    # Extend the weights tensor to 4D
    weights_tensor = t_op.tmp_inputs[1]
    if tensor_has_data(weights_tensor):
        # Do it statically: the weights are constant, so the buffer can simply
        # be reshaped at conversion time.
        weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
        weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape(
            conv_2d_weight_shape
        )

    else:
        # Add a Reshape before the weights tensor, since dynamic weights must be
        # reshaped at runtime.
        new_weights_tensor = self.builder.duplicate_tensor(
            weights_tensor, name_suffix="_4D_"
        )
        new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)

        weight_reshape = tflite_model.Operator(
            builtin_options=reshape_options.Reshape(conv_2d_weight_shape)
        )
        weight_reshape.tmp_inputs = [weights_tensor]
        weight_reshape.tmp_outputs = [new_weights_tensor]

        pre_reshapes.append(weight_reshape)

        # Save the new weights tensor, to assign it later.
        weights_tensor = new_weights_tensor

    # -- Create the new operators --
    reshape1 = tflite_model.Operator(
        builtin_options=reshape_options.Reshape(conv_2d_input_shape)
    )
    reshape1.tmp_inputs = [reshape1_input]
    reshape1.tmp_outputs = [reshape1_output]
    pre_reshapes.append(reshape1)

    # The final Reshape restores the original (3D) output shape.
    reshape2 = tflite_model.Operator(
        builtin_options=reshape_options.Reshape(reshape2_output.shape.vector)
    )
    reshape2.tmp_inputs = [reshape2_input]
    reshape2.tmp_outputs = [reshape2_output]

    # Assign the new input and output of the Conv2D
    t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[
        2:
    ]  # Add bias as well, if present
    t_op.tmp_outputs = [reshape2_input]

    # Extend all Conv attributes to 2D
    common.extend_1d_stride_to_2d(conv_params.stride)
    common.extend_1d_dilation_to_2d(conv_params.dilation)
    common.extend_1d_padding_to_2d(conv_params.padding)

    # Convert the now 2D Conv
    converted_conv_ops = self._convert_2d_conv(t_op, conv_params)

    return pre_reshapes + converted_conv_ops + [reshape2]

# noinspection PyPep8Naming
def _convert_unpadded_2D(
Expand Down Expand Up @@ -182,9 +289,19 @@ def _convert_2d_conv(
aten_translator.convert_padding(conv_params.padding)
)
if explicit_padding is not None:
# Need to prepend a 'Pad' operator, which adds 0s.
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
input_quantization = t_op.tmp_inputs[0].quantization
pad_value = (
None
if input_quantization is None
else np.array(input_quantization.zero_point[0]).astype(
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
)
)
conversion_result.ops_list.add_pre(
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
self.builder.create_pad_operator_before(
t_op, 0, explicit_padding, constant_value=pad_value
)
)

# DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
Expand Down Expand Up @@ -221,9 +338,19 @@ def _convert_2d_conv(
aten_translator.convert_padding(conv_params.padding)
)
if explicit_padding is not None:
# Need to prepend a 'Pad' operator, which adds 0s.
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
input_quantization = t_op.tmp_inputs[0].quantization
pad_value = (
None
if input_quantization is None
else np.array(input_quantization.zero_point[0]).astype(
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
)
)
conversion_result.ops_list.add_pre(
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
self.builder.create_pad_operator_before(
t_op, 0, explicit_padding, constant_value=pad_value
)
)

return conversion_result.ops_list.flatten()
Expand All @@ -237,7 +364,9 @@ def convert(self, node: Node):
conv_params = ConvParameters(stride, padding, dilation, groups)

rank = t_op.tmp_inputs[1].shape.len()
if rank == 4: # Conv2D
if rank == 3: # Conv1D
ops_to_add = self._convert_1d_conv(t_op, conv_params)
elif rank == 4: # Conv2D
ops_to_add = self._convert_2d_conv(t_op, conv_params)
else:
raise NotImplementedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
)
from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
tf_lite_type_to_numpy,
)
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
Expand Down Expand Up @@ -289,9 +292,17 @@ def build_input_tensor_padding(

tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding)
if explicit_padding is not None:
# Must add extra 'Pad' operator
# Must add extra 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
input_quantization = t_op.tmp_inputs[0].quantization
pad_value = (
None
if input_quantization is None
else np.array(input_quantization.zero_point[0]).astype(
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
)
)
return tfl_padding, builder.create_pad_operator_before(
t_op, input_idx, explicit_padding
t_op, input_idx, explicit_padding, pad_value
)

return tfl_padding, None
Expand Down
6 changes: 3 additions & 3 deletions backends/nxp/tests/executorch_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):

def to_quantized_edge_program(
model: torch.nn.Module,
input_shapes: tuple[int] | list[tuple[int]],
input_shapes: tuple[int, ...] | list[tuple[int, ...]],
operators_not_to_delegate: list[str] = None,
target="imxrt700",
neutron_converter_flavor="SDK_25_03",
Expand Down Expand Up @@ -100,7 +100,7 @@ def to_quantized_edge_program(


def to_quantized_executorch_program(
model: torch.nn.Module, input_shapes: tuple[int] | list[tuple[int]]
model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
) -> ExecutorchProgramManager:
edge_program_manager = to_quantized_edge_program(model, input_shapes)

Expand All @@ -110,7 +110,7 @@ def to_quantized_executorch_program(


def to_edge_program(
model: nn.Module, input_shapes: tuple[int] | list[tuple[int]]
model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
) -> EdgeProgramManager:
if isinstance(input_shapes, list):
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
Expand Down
Loading
Loading