diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 611f4348d55e..cbd633324a75 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -909,7 +909,7 @@ def _impl_v14(cls, bb, inputs, attr, params):
         if len(inputs) > 1:
             k = get_constant(inputs[1], params)
             if isinstance(k, relax.Constant):
-                k = int(k.data.numpy()[0])
+                k = int(k.data.numpy().item())
             else:
                 raise ValueError("Currently only support constant k for Trilu op.")
         else:
@@ -1588,6 +1588,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
         return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0))
 
 
+def get_prim_value_list(values):
+    new_values = []
+    for v in list(values):
+        if isinstance(v, relax.expr.PrimExpr):
+            new_values.append(relax.PrimValue(v))
+        else:
+            new_values.append(v)
+    return new_values
+
+
 class Slice(OnnxOpConverter):
     """Converts an onnx Splice node into an equivalent Relax expression."""
@@ -1641,7 +1651,12 @@ def _impl_v13(cls, bb, inputs, attr, params):
         assume_inbound = not all(
             [isinstance(param, (tir.IntImm, int)) for param in [*starts, *ends, *steps]]
         )
-        # return relax.op.strided_slice(data, axes, starts, ends, steps)
+
+        # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr
+        starts = get_prim_value_list(starts)
+        ends = get_prim_value_list(ends)
+        steps = get_prim_value_list(steps)
+
         return relax.op.strided_slice(
             data, axes, starts, ends, steps, assume_inbound=assume_inbound
         )
@@ -1730,9 +1745,21 @@ class Expand(OnnxOpConverter):
     def _impl_v13(cls, bb, inputs, attr, params):
         data = inputs[0]
         shape = inputs[1]
         if isinstance(shape, relax.ShapeExpr):
-            return relax.op.broadcast_to(data, shape)
+            data_shape = list(data.struct_info.shape)
+            target_shape = list(shape.values)
+            data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
+            assert len(data_shape) == len(target_shape)
+            # Fix small target shapes or target shapes assigned to -1
+            for i, s in enumerate(target_shape):
+                if isinstance(s, tvm.tir.IntImm) and (
+                    (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
+                    or s.value == -1
+                ):
+                    target_shape[i] = data_shape[i]
+            if target_shape == data_shape:
+                return data
+            return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
 
         # If possible, directly expand to constant shape.
         if isinstance(shape, relax.Constant):
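Note on the Expand hunk above: the fix-up only rewrites target dims that are concrete IntImms, so symbolic dims (for example a dynamic batch variable) pass through to broadcast_to untouched. Below is a minimal, self-contained sketch of the same rule using plain Python ints and strings in place of tir.IntImm and tir vars; the helper name fix_target_shape is illustrative and not part of the patch.

def fix_target_shape(data_shape, target_shape):
    """Plain-Python mirror of the Expand fix-up: left-pad the input shape with 1s,
    then replace target dims that are -1 or smaller than the matching input dim."""
    data_shape = [1] * (len(target_shape) - len(data_shape)) + list(data_shape)
    assert len(data_shape) == len(target_shape)
    fixed = list(target_shape)
    for i, s in enumerate(target_shape):
        # Only concrete ints are rewritten; symbolic dims (strings here, tir vars
        # in the frontend) pass through unchanged.
        if isinstance(s, int) and (
            (isinstance(data_shape[i], int) and s < data_shape[i]) or s == -1
        ):
            fixed[i] = data_shape[i]
    return fixed

print(fix_target_shape([3, 1], [-1, 4]))                 # -> [3, 4]
print(fix_target_shape([1, 32, 32], ["batch", 32, 32]))  # -> ['batch', 32, 32]

The first call shows a -1 target dim being replaced by the input dim; the second shows the symbolic "batch" dim left alone, which is the case the new dynamic test exercises.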
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 9faa441138fc..c130bf43730b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1507,10 +1507,6 @@ def test_topk(axis: int, largest: int):
 
 @pytest.mark.parametrize("dynamic", [False, True])
 def test_expand(dynamic):
-    if dynamic:
-        # TODO: Support dynamic shape for Expand
-        pytest.skip("Dynamic expand is not supported yet")
-
     def _test_expand(name, data, shape, ref_data):
         shape_array = np.array(shape)
         shape_node = onnx.helper.make_node(
@@ -1541,17 +1537,43 @@ def _test_expand(name, data, shape, ref_data):
         model = helper.make_model(graph, producer_name=name)
         check_correctness(model, inputs={"in": data})
 
-    in_shape = (3, 1)
-    shape = (3, 4)
-    data = np.random.uniform(size=in_shape).astype(np.float32)
-    ref_data = np.tile(data, 4)
-    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
-
-    in_shape = (3, 1)
-    shape = (1, 3, 4)
-    data = np.random.uniform(size=in_shape).astype(np.float32)
-    ref_data = np.tile(data, (1, 1, 4))
-    _test_expand("expand_with_diff_dim", data, shape, ref_data)
+    def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
+        shape_node = onnx.helper.make_node("Shape", inputs=["in_2"], outputs=["shape"])
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+        in_shape = list(data.shape)
+        out_shape = list(ref_data.shape)
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_test",
+            inputs=[
+                helper.make_tensor_value_info("in", TensorProto.FLOAT, in_shape),
+                helper.make_tensor_value_info("in_2", TensorProto.FLOAT, shape),
+            ],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name=name)
+        check_correctness(model, inputs={"in": data, "in_2": shape_data})
+
+    if not dynamic:
+        in_shape = (3, 1)
+        shape = (3, 4)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, 4)
+        _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
+
+        in_shape = (3, 1)
+        shape = (1, 3, 4)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        ref_data = np.tile(data, (1, 1, 4))
+        _test_expand("expand_with_diff_dim", data, shape, ref_data)
+    else:
+        in_shape = (1, 32, 32)
+        shape = ("batch", 32, 32)
+        data = np.random.uniform(size=in_shape).astype(np.float32)
+        shape_data = np.random.uniform(size=(64, 32, 32)).astype(np.float32)
+        ref_data = np.tile(data, (64, 1, 1))
+        _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
 
     # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
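For a quick manual check of the new dynamic path outside the test suite, the same Shape -> Expand pattern used by _test_expand_dynamic_shapeexpr can be imported directly. A minimal sketch, assuming a TVM build with the Relax ONNX frontend; the graph/producer names and the final mod.show() call are only for illustration.

from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# "in_2" only supplies the target shape through a Shape node, mirroring the new test.
shape_node = helper.make_node("Shape", inputs=["in_2"], outputs=["shape"])
expand_node = helper.make_node("Expand", inputs=["in", "shape"], outputs=["out"])
graph = helper.make_graph(
    [shape_node, expand_node],
    "expand_dynamic_sketch",
    inputs=[
        helper.make_tensor_value_info("in", TensorProto.FLOAT, [1, 32, 32]),
        helper.make_tensor_value_info("in_2", TensorProto.FLOAT, ["batch", 32, 32]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, ["batch", 32, 32])],
)
model = helper.make_model(graph, producer_name="expand_dynamic_sketch")

mod = from_onnx(model)
mod.show()  # Expand is lowered through relax.op.broadcast_to on a symbolic target shape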