diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index e3f9bad17ef5..0e04b0936f24 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -60,6 +60,18 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
   }
 };  // struct ExpandDimsAttrs
 
+/*! \brief Attributes used in dynamic expand_dims operators */
+struct DynExpandDimsAttrs : public tvm::AttrsNode<DynExpandDimsAttrs> {
+  int num_newaxis;
+
+  TVM_DECLARE_ATTRS(DynExpandDimsAttrs, "relay.attrs.DynExpandDimsAttrs") {
+    TVM_ATTR_FIELD(num_newaxis)
+        .describe("Number of axes to be inserted. Should be >= 0.")
+        .set_lower_bound(0)
+        .set_default(1);
+  }
+};  // struct DynExpandDimsAttrs
+
 /*! \brief Attributes used in concatenate operators */
 struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
   int axis;
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 4d48f5796aca..233d991969f0 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1462,6 +1462,26 @@ def _impl_v1(cls, inputs, attr, params):
             inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1)
         return inputs[0]
 
+    @classmethod
+    def _impl_v13(cls, inputs, attr, params):
+        rank_input = len(infer_type(inputs[0]).checked_type.shape)
+        num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
+        axes = relay.split(inputs[1], num_new_axis).astuple()
+
+        result = inputs[0]
+
+        # TODO (AndrewZhaoLuo): investigate performance issues with consecutive
+        # dynamic expand_dims on non-llvm targets.
+        for i in range(num_new_axis):
+            axis = relay.TupleGetItem(axes, i)
+            # Unpack scalar
+            axis = relay.reshape(axis, [])
+            axis = relay.If(
+                axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
+            )
+            result = _op.expand_dims(result, axis)
+        return result
+
 
 class Split(OnnxOpConverter):
     """Operator converter for Split."""
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index de8ee0895462..c8235ec9375a 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -20,10 +20,12 @@
 from tvm.runtime import convert
 from tvm.te.hybrid import script
+
 from .. import op as _reg
 
 _reg.register_broadcast_schedule("dyn.broadcast_to")
 _reg.register_injective_schedule("dyn.reshape")
+_reg.register_injective_schedule("dyn.expand_dims")
 _reg.register_broadcast_schedule("dyn.tile")
 _reg.register_injective_schedule("dyn.one_hot")
 _reg.register_injective_schedule("dyn.full")
@@ -89,6 +91,42 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
     return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
 
 
+@script
+def _expand_dims_shape_func_input_data(data, axis, ndims, num_newaxis):
+    out = output_tensor((ndims,), "int64")
+
+    for i in const_range(ndims):
+        if i < axis:
+            # We multiply by a check (i < len(data.shape)) to avoid
+            # a constant folding mechanism leading to an overflow
+            out[i] = int64(data.shape[i * (i < len(data.shape))])
+        elif i - num_newaxis < axis:
+            out[i] = int64(1)
+        else:
+            out[i] = int64(
+                # We can't use axis in indices as it is not constant but we can
+                # use negative indices (kind of, have to manually do it)
+                data.shape[
+                    (i - num_newaxis) * (i - num_newaxis >= 0)
+                    + (i - num_newaxis + len(data.shape)) * (i - num_newaxis < 0)
+                ]
+            )
+
+    return out
+
+
+@_reg.register_shape_func("dyn.expand_dims", [True, True])
+def dynamic_expand_dims_shape_func(attrs, inputs, out_ndims):
+    return [
+        _expand_dims_shape_func_input_data(
+            inputs[0],
+            inputs[1],
+            out_ndims[0],
+            convert(attrs.num_newaxis),
+        )
+    ]
+
+
 @script
 def _tile_shape_func(data, reps, ndim, tndim, rndim):
     out = output_tensor((tndim,), "int64")
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 2c299022bd6e..fe1a73ca231a 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -96,7 +96,7 @@ def expand_dims(data, axis, num_newaxis=1):
     data : relay.Expr
         The input data to the operator.
 
-    axis : int
+    axis : Union[int, Expr]
         The axis at which the input array is expanded.
         Should lie in range `[-data.ndim - 1, data.ndim]`.
         If `axis < 0`, it is the first axis inserted;
@@ -110,7 +110,13 @@ def expand_dims(data, axis, num_newaxis=1):
     result : relay.Expr
         The reshaped result.
     """
-    return _make.expand_dims(data, axis, num_newaxis)
+    if isinstance(axis, int):
+        return _make.expand_dims(data, axis, num_newaxis)
+    if isinstance(axis, Expr):
+        # TODO (AndrewZhaoLuo): investigate performance issues with consecutive
+        # dynamic expand_dims on non-llvm targets.
+        return _dyn_make.expand_dims(data, axis, num_newaxis)
+    raise ValueError(f"Unknown type for axis: {type(axis)}")
 
 
 def transpose(data, axes=None):
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index d8ee1c84a99c..848d058f0af3 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -618,6 +618,80 @@ RELAY_REGISTER_OP("dyn.sparse_to_dense")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
 
+/* relay.dyn.expand_dims */
+TVM_REGISTER_NODE_TYPE(DynExpandDimsAttrs);
+
+bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  ICHECK_EQ(num_inputs, 2);
+  const auto* data_type = types[0].as<TensorTypeNode>();
+  if (data_type == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "expand_dims: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto* param = attrs.as<DynExpandDimsAttrs>();
+
+  // We don't know the output shape until we see the value of the axis input
+  int ndim = data_type->shape.size();
+  Array<IndexExpr> oshape(ndim + param->num_newaxis, Any());
+
+  const auto* axis_type = types[1].as<TensorTypeNode>();
+  ICHECK(axis_type->shape.size() == 0) << "Axis should be a scalar, got shape " << axis_type->shape;
+
+  // Set output shape
+  reporter->Assign(types[2], TensorType(oshape, data_type->dtype));
+  return true;
+}
+
+Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                    const Type& out_type) {
+  // inputs = [Input tensor, axis to expand]
+  ICHECK_EQ(inputs.size(), 2);
+
+  const auto* param = attrs.as<DynExpandDimsAttrs>();
+
+  Array<IndexExpr> ishape = inputs[0]->shape;
+  const TensorTypeNode* out_ttype = out_type.as<TensorTypeNode>();
+  int ndim_out = out_ttype->shape.size();
+  int ndim_in = ishape.size();
+  ICHECK_EQ(ndim_in + param->num_newaxis, ndim_out);
+
+  Array<PrimExpr> newshape;
+  for (auto val : out_ttype->shape) {
+    // These vars will be populated by the VM executor with the results
+    // of the shape_func for the op.
+    newshape.push_back(val.as<tir::AnyNode>()->ToVar());
+  }
+
+  return {topi::reshape(inputs[0], newshape)};
+}
+
+Expr MakeExpandDims(Expr data, Expr axis_tensor, int num_newaxis) {
+  auto attrs = make_object<DynExpandDimsAttrs>();
+  attrs->num_newaxis = num_newaxis;
+  static const Op& op = Op::Get("dyn.expand_dims");
+  return Call(op, {data, axis_tensor}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.expand_dims").set_body_typed(MakeExpandDims);
+
+RELAY_REGISTER_OP("dyn.expand_dims")
+    .describe(R"code(Insert `num_newaxis` new axes at the position given by `axis`.
+
+- **data**: The input data to the operator.
+- **axis**: The axis at which to insert the new axes.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("axis", "Tensor", "The axis at which to insert the new axes.")
+    .set_support_level(3)
+    .add_type_rel("DynamicExpandDims", ExpandDimsRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 91d3911da530..1373c1c56eee 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5015,16 +5015,13 @@ def verify_eyelike(indata):
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_unique_sorted_with_axis",
-    "test_unique_sorted_with_axis_3d",
-    "test_unique_sorted_with_negative_axis",
-    "test_unsqueeze_axis_0",
-    "test_unsqueeze_axis_1",
-    "test_unsqueeze_axis_2",
-    "test_unsqueeze_negative_axes",
+    # These unsqueeze tests work, but take 2+ hrs to run
     "test_unsqueeze_three_axes",
     "test_unsqueeze_two_axes",
     "test_unsqueeze_unsorted_axes",
+    "test_unique_sorted_with_axis",
+    "test_unique_sorted_with_axis_3d",
+    "test_unique_sorted_with_negative_axis",
     "test_upsample_nearest",
 ]
diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py
index d2ad5a47f15b..8c57e1dc4a9f 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level3.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -19,11 +19,10 @@
 import numpy as np
 import pytest
 import tvm
-from tvm import te
-from tvm import relay
+import tvm.testing
+from tvm import relay, te
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
-import tvm.testing
 
 
 def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets()):
@@ -93,6 +92,35 @@ def verify_reshape(shape, newshape, oshape):
     verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
 
 
+@tvm.testing.uses_gpu
+def test_dyn_expand_dims():
+    def verify_expand_dims(
+        dshape, dtype, oshape, axis, num_newaxis, target_device=tvm.testing.enabled_targets()
+    ):
+        # Use 1 to avoid issues with invalid buffer sizes
+        x = relay.Var("x", relay.TensorType(dshape, dtype))
+        y = relay.var("axis", shape=[], dtype="int64")
+        z = relay.expand_dims(x, axis=y, num_newaxis=num_newaxis)
+        func = relay.Function([x, y], z)
+
+        data_np = np.random.uniform(size=dshape).astype(dtype)
+        axis_np = np.array(axis).astype("int64")
+        ref_res = data_np.reshape(oshape)
+        verify_func(func, [data_np, axis_np], ref_res, target_device=target_device)
+
+    for dtype in ["float16", "float32"]:
+        verify_expand_dims((2, 2), dtype, (2, 2, 1), 2, 1)
+        verify_expand_dims((2, 2), dtype, (2, 1, 2), 1, 1)
+        verify_expand_dims((2, 2), dtype, (1, 2, 2), 0, 1)
+
+        # TODO (AndrewZhaoLuo): investigate why runtimes in non-llvm are extremely slow
+        # for multiple new axes
+        llvm_target_only = [x for x in tvm.testing.enabled_targets() if "llvm" in x]
+        verify_expand_dims((2, 2), dtype, (2, 2, 1, 1), 2, 2, target_device=llvm_target_only)
+        verify_expand_dims((2, 2), dtype, (2, 1, 1, 1, 2), 1, 3, target_device=llvm_target_only)
+        verify_expand_dims((2, 2), dtype, (1, 1, 1, 1, 2, 2), 0, 4, target_device=llvm_target_only)
+
+
 @tvm.testing.uses_gpu
 def test_dyn_tile():
     def verify_tile(dshape, reps):
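For reference, a minimal usage sketch of the new dynamic expand_dims path from the Python API (not part of the diff above); it assumes an LLVM target is available and uses the VM executor, mirroring the new test rather than documenting additional behavior:

import numpy as np
import tvm
from tvm import relay

# Passing a relay.Expr as the axis now dispatches relay.expand_dims to dyn.expand_dims.
x = relay.var("x", shape=(2, 2), dtype="float32")
axis = relay.var("axis", shape=[], dtype="int64")
func = relay.Function([x, axis], relay.expand_dims(x, axis=axis, num_newaxis=1))
mod = tvm.IRModule.from_expr(func)

# Dynamic output shapes require the VM executor; "llvm" is an assumed available target.
run = relay.create_executor("vm", mod=mod, target="llvm").evaluate()
out = run(np.zeros((2, 2), dtype="float32"), np.array(1, dtype="int64"))
print(out.shape)  # expected: (2, 1, 2)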