diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index dd4b8a425425..799a32e405f4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1917,18 +1917,20 @@ def _impl_v13(cls, bb, inputs, attr, params): # If possible, directly expand to constant shape. if isinstance(shape, relax.Constant): new_shape = shape.data.numpy().tolist() - # For some reason, onnx allows target shapes to be smaller than input shapes. - # We need to go correct it. + # ONNX Expand operator requires preserving target rank and broadcasting + # according to standard rules. Dimensions are right-aligned. data_shape = [dim.value for dim in data.struct_info.shape] - # Dimensions are right alignment. - data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape - # Fix small target shapes. - for i, s in enumerate(new_shape): - if i < len(data_shape) and s < data_shape[i]: + + # Right-align the shapes + if len(new_shape) > len(data_shape): + data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape + else: + new_shape = [1] * (len(data_shape) - len(new_shape)) + new_shape + # Fix small target shapes - if target dim is smaller than input dim + # use the input dim (ONNX-specific behavior). + for i in range(len(new_shape)): + if new_shape[i] < data_shape[i]: new_shape[i] = data_shape[i] - # If the new shape matches the input shape, no transformation is needed. - if new_shape == data_shape: - return data return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape)) # Otherwise handle dynamic shapes.