From 80f1ce90f9043b92b4f5b3d31ef7564d089d6062 Mon Sep 17 00:00:00 2001 From: Anurag Singh <10385586+singh20anurag@users.noreply.github.com> Date: Tue, 25 Mar 2025 12:55:03 -0500 Subject: [PATCH] [ONNX] Fix Expand operator to properly handle target shapes This fixes issue #17746 where the ONNX Expand operator was not correctly expanding tensors to higher dimensions. The issue manifested when a downstream ArgMin operation received a tensor with fewer dimensions than expected, causing an 'axis out of bounds' error. Specifically: 1. The Expand op was incorrectly skipping the broadcast when input and target shapes had the same values but different ranks 2. This caused a tensor with shape [5,60] to remain [5,60] when it should have been expanded to [1,1,5,60] 3. The subsequent ArgMin op with axis=2 then failed as the tensor only had 2 dimensions instead of the expected 4 The fix ensures that Expand always broadcasts to the target shape, preserving the rank specified in the ONNX model. This allows downstream operations to work with the correct tensor dimensions. Fixes #17746 --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index dd4b8a425425..799a32e405f4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1917,18 +1917,20 @@ def _impl_v13(cls, bb, inputs, attr, params): # If possible, directly expand to constant shape. if isinstance(shape, relax.Constant): new_shape = shape.data.numpy().tolist() - # For some reason, onnx allows target shapes to be smaller than input shapes. - # We need to go correct it. + # ONNX Expand operator requires preserving target rank and broadcasting + # according to standard rules. Dimensions are right-aligned. data_shape = [dim.value for dim in data.struct_info.shape] - # Dimensions are right alignment. - data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape - # Fix small target shapes. - for i, s in enumerate(new_shape): - if i < len(data_shape) and s < data_shape[i]: + + # Right-align the shapes + if len(new_shape) > len(data_shape): + data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape + else: + new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape + # Fix small target shapes - if target dim is smaller than input dim + # use the input dim (ONNX-specific behavior). + for i in range(len(new_shape)): + if new_shape[i] < data_shape[i]: new_shape[i] = data_shape[i] - # If the new shape matches the input shape, no transformation is needed. - if new_shape == data_shape: - return data return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape)) # Otherwise handle dynamic shapes.