diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f3a5e9098de8..dfd092314cbc 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2119,13 +2119,18 @@ def _impl_v18(cls, bb, inputs, attr, params): # Define relax implementation. if roi is not None: - roi = relax.op.concat( - [ - relax.op.strided_slice(roi, axes=[0], begin=[2], end=[ndims]), - relax.op.strided_slice(roi, axes=[0], begin=[ndims + 2], end=[2 * ndims]), - ], - axis=0, - ) + if isinstance(roi, relax.Constant): + roi = roi.data.numpy().tolist() + else: + roi = relax.op.concat( + [ + relax.op.strided_slice(roi, axes=[0], begin=[2], end=[ndims]), + relax.op.strided_slice(roi, axes=[0], begin=[ndims + 2], end=[2 * ndims]), + ], + axis=0, + ) + # TODO The backend C++ func resize2d does not support dynamic ROI for now. + raise NotImplementedError("Dynamic ROI is not supported in resize2d for now.") else: roi = [0.0] * 4 diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py index 6bec22161dbc..afadbf35fb6b 100644 --- a/python/tvm/relax/op/image/image.py +++ b/python/tvm/relax/op/image/image.py @@ -104,6 +104,10 @@ def resize2d( roi = (0.0, 0.0, 0.0, 0.0) # type: ignore elif isinstance(roi, float): roi = (roi, roi, roi, roi) # type: ignore + elif isinstance(roi, (tuple, list)): + roi = tuple(val if isinstance(val, float) else float(val) for val in roi) + else: + raise NotImplementedError(f"Unsupported roi type {type(roi)}") if isinstance(size, (int, PrimExpr)): size = (size, size) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 0d532e07fc33..74d75f5abdd7 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2359,8 +2359,34 @@ def verify_tile(in_shape, repeats, out_shape): verify_tile(x.shape, repeats, z_array.shape) -def test_resize(): - resize_node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="cubic") +def _generate_roi_cases(): + # Base case when with_roi is False + roi_list = [ + pytest.param(False, None, id="no_roi"), + ] + + # Valid when with_roi is True + roi_cases = [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0], + [0.1, 0.1, 0.9, 0.9], + [0.2, 0.2, 0.8, 0.8], + [0.3, 0.3, 0.7, 0.7], + [0.4, 0.4, 0.6, 0.6], + [0.5, 0.5, 0.5, 0.5], + [0.1, 0.2, 0.9, 0.8], + ] + for roi in roi_cases: + roi_list.append(pytest.param(True, roi, id=f"roi_{'_'.join(str(x) for x in roi)}")) + + return roi_list + + +@pytest.mark.parametrize("with_roi, roi_list", _generate_roi_cases()) +def test_resize(with_roi, roi_list): + resize_node = helper.make_node( + "Resize", ["X", "roi" if with_roi else "", "scales"], ["Y"], mode="cubic" + ) graph = helper.make_graph( [resize_node], @@ -2370,6 +2396,11 @@ def test_resize(): ], initializer=[ helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0]), + *( + [helper.make_tensor("roi", TensorProto.FLOAT, [4], [0.0, 0.0, 0.0, 0.0])] + if with_roi + else [] + ), ], outputs=[ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64]),