diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
new file mode 100644
index 000000000000..bd6ae17bcf1c
--- /dev/null
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/manipulate.h
+ * \brief Attributes for tensor manipulation operators.
+ */
+#ifndef TVM_RELAX_ATTRS_MANIPULATE_H_
+#define TVM_RELAX_ATTRS_MANIPULATE_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/tir/index_map.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in concat operators */
+struct ConcatAttrs : public tvm::AttrsNode<ConcatAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(ConcatAttrs, "relax.attrs.ConcatAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis at which the input arrays are concatenated. "
+        "Should lie in range `[-ndim, ndim)`.");
+  }
+};  // struct ConcatAttrs
+
+/*! \brief Attributes used in expand_dims operators */
+struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
+  Array<Integer> axis;
+
+  TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relax.attrs.ExpandDimsAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axes at which the input array is expanded. "
+        "All values are required to lie in range `[-data.ndim - 1, data.ndim]`, "
+        "with the convention of negative indexing.");
+  }
+};  // struct ExpandDimsAttrs
+
+/*! \brief Attributes used in layout_transform operator */
+struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
+  tir::IndexMap index_map;
+  // pad_value is chosen to be of PrimValue type, as it represents a constant TIR POD expression.
+  // This needs to be revisited in case PrimValue evolves to represent symbolic expressions in the
+  // future.
+  Optional<PrimValue> pad_value;
+
+  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
+    TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
+    TVM_ATTR_FIELD(pad_value).describe(
+        "The specific value to be used to pad if the layout transform would result in implicit "
+        "padding. If not specified, the compiler is free to choose any value.");
+  }
+};  // struct LayoutTransformAttrs
+
+/*! \brief Attributes used in permute_dims operator */
+struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
+  Optional<Array<Integer>> axes;
+
+  TVM_DECLARE_ATTRS(PermuteDimsAttrs, "relax.attrs.PermuteDimsAttrs") {
+    TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
+  }
+};  // struct PermuteDimsAttrs
+
+/*! \brief Attributes used in split operator */
+struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
+  ObjectRef indices_or_sections;
+  int axis;
+
+  TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") {
+    TVM_ATTR_FIELD(indices_or_sections)
+        .describe("The input array of indices or the number of split sections.");
+    TVM_ATTR_FIELD(axis).describe("The axis to be split.");
+  }
+};  // struct SplitAttrs
+
+/*! \brief Attributes used in squeeze operators */
+struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
+  Optional<Array<Integer>> axis;
+
+  TVM_DECLARE_ATTRS(SqueezeAttrs, "relax.attrs.SqueezeAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axes to squeeze in the input tensor. "
+        "If `axis = None`, all axes of dimension 1 get squeezed; "
+        "else, the dimensions in `axis` get squeezed. "
+        "It is an error if an axis does not have dimension 1.");
+  }
+};  // struct SqueezeAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_MANIPULATE_H_
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index fa9c81522596..a46c62e1f12b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -15,18 +15,161 @@
 # specific language governing permissions and limitations
 # under the License.
 """Manipulation operators."""
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
 
 from tvm.ir.expr import PrimExpr
-
+from tvm.tir import IntImm, FloatImm, IndexMap
 from . import _ffi_api
-from ..expr import Expr
+from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple
 
 PrimExprLike = Union[int, PrimExpr]
 
 
+def broadcast_to(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
+    """Broadcasts a tensor to a specified shape.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data to the operator.
+
+    shape : Union[Tuple[PrimExprLike], Expr]
+        The target shape.
+
+    Returns
+    -------
+    result : relax.Expr
+        The broadcasted tensor.
+    """
+    if isinstance(shape, (tuple, list)):
+        shape = ShapeExpr(shape)
+    return _ffi_api.broadcast_to(x, shape)  # type: ignore
+
+
+def concat(tensors: Union[Expr, List[Expr]], axis: Optional[int] = 0) -> Expr:
+    """Concatenate the input tensors along the given axis.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing the tensors to be concatenated,
+        or a list of Tensors.
+
+    axis : Optional[int]
+        The axis along which the tensors are concatenated.
+        If `axis` is `None`, the input tensors are required to be flattened before concatenation.
+
+    Returns
+    -------
+    result: relax.Expr
+        The concatenated tensor.
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.concat(tensors, axis)  # type: ignore
+
+
+def expand_dims(x: Expr, axis: Union[int, List[int]]) -> Expr:
+    """Insert new axes at the positions given by `axis`.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data to the operator.
+
+    axis : Union[int, List[int]]
+        The axes at which the input array is expanded.
+        All values are required to lie in range `[-data.ndim - 1, data.ndim]`, with the convention
+        of negative indexing.
+
+    Returns
+    -------
+    result : relax.Expr
+        The transformed result.
+    """
+    if isinstance(axis, int):
+        axis = [axis]
+    return _ffi_api.expand_dims(x, axis)  # type: ignore
+
+
+def flatten(x: Expr) -> Expr:
+    """Flatten all the tensor dimensions into one.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : relax.Expr
+        The flattened result.
+    """
+    return _ffi_api.flatten(x)  # type: ignore
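+
+# Illustrative only (not part of the module itself): assuming a 3-D float32
+# input, the struct info inferred for these ops works out as follows; the
+# variable names are hypothetical.
+#
+#   x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+#   flatten(x)                   # R.Tensor((24,), "float32")
+#   expand_dims(x, axis=[1, 3])  # R.Tensor((2, 1, 3, 1, 4), "float32")
+#   concat([x, x], axis=0)       # R.Tensor((4, 3, 4), "float32")
+#   broadcast_to(expand_dims(x, axis=[1]), (2, 5, 3, 4))
+#                                # the unit axis at position 1 is broadcast to 5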
+ """ + return _ffi_api.flatten(x) # type: ignore + + +def layout_transform( + x: Expr, + index_map: Union[Callable, IndexMap], + pad_value: Optional[Union[int, float, PrimValue]] = None, +): + """Modifies the layout of a tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor to the operator. + + index_map : Union[Callable, IndexMap] + The transformation to apply. + + pad_value : Optional[Union[int, float, PrimValue]] + The value used for padding if the transformation results in implicit padding. + If not specified, any value can be used. + + Returns + ------- + result : relax.Expr + The transformed tensor. + """ + if callable(index_map): + index_map = IndexMap.from_func(index_map) + x_dtype = x.checked_type.dtype + + # Explicitly convert python int/float pad_value to the x's type. If the default behavior + # is applied, it would be converted to int32/float32, which may not match the x's type. + if pad_value is None: + pass + elif not isinstance(pad_value, PrimValue): + if "int" in x_dtype and isinstance(pad_value, int): + pad_value = IntImm(x_dtype, pad_value) + elif "float" in x_dtype and (isinstance(pad_value, (int, float))): + pad_value = FloatImm(x_dtype, float(pad_value)) + pad_value = PrimValue(pad_value) + return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore + + +def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr: + """Permutes the dimensions of an array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axes : Optional[List[int]] + The target axes order, reverse order if not specified. + + Returns + ------- + result : relax.Expr + The transposed result. + """ + return _ffi_api.permute_dims(x, axes) # type: ignore + + def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: """Reshape the input array. @@ -60,3 +203,61 @@ def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: compile-time, an error will be thrown. """ return _ffi_api.reshape(x, shape) # type: ignore + + +def split( + x: Expr, + indices_or_sections: Union[int, List[PrimExprLike]], + axis: int = 0, +) -> Expr: + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis (if possible). Last section will be smaller if the tensor + size along the given dimension is not divisible by the integer. + + If indices_or_sections is a tuple of mixture of int or PrimExpr, + the entries indicate the indices where along axis the array is split. + + Parameters + ---------- + x : relax.Expr + The tensor to be split. + + indices_or_sections : Union[int, List[PrimExprLike]] + Indices or sections to split into. Accepts an int or a list. + + axis : int + The axis over which to split. + + Returns + ------- + ret : relax.Expr + The computed result. + """ + if isinstance(indices_or_sections, int): + indices_or_sections = IntImm("int64", indices_or_sections) + return _ffi_api.split(x, indices_or_sections, axis) # type: ignore + + +def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr: + """Squeeze axes in the array. + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + axis : Optional[Union[int, List[int]] + The set of axes to remove. + If axis = None, remove all axis of dimensions 1. + If any specified axis has dimension that does not equal 1, it is an error. + + Returns + ------- + result : relax.Expr + The squeezed result. 
+ """ + if isinstance(axis, int): + axis = [axis] + return _ffi_api.squeeze(x, axis) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 3a7ed427f9bf..efad5d98f01a 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -89,6 +89,36 @@ class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" +@tvm._ffi.register_object("relax.attrs.ConcatAttrs") +class ConcatAttrs(Attrs): + """Attributes for concat operator""" + + +@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs") +class ExpandDimsAttrs(Attrs): + """Attributes for expand_dims operator""" + + +@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs") +class PermuteDimsAttrs(Attrs): + """Attributes for permute_dims operator""" + + +@tvm._ffi.register_object("relax.attrs.SplitAttrs") +class SplitAttrs(Attrs): + """Attributes used in split operator""" + + +@tvm._ffi.register_object("relax.attrs.SqueezeAttrs") +class SqueezeAttrs(Attrs): + """Attributes for squeeze operator""" + + +@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs") +class LayoutTransformAttrs(Attrs): + """Attributes used in layout_transform operator""" + + @tvm._ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index b779bdac9c13..da6010190004 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -39,17 +39,21 @@ add, assert_op, astype, + broadcast_to, builtin, call_builtin_with_ctx, call_tir, ceil, clip, + concat, cos, cosh, divide, equal, ewise_fma, exp, + expand_dims, + flatten, floor, floor_divide, full, @@ -61,6 +65,7 @@ isfinite, isinf, isnan, + layout_transform, less, less_equal, linear, @@ -75,6 +80,7 @@ negative, not_equal, null_value, + permute_dims, ones, ones_like, print, @@ -91,11 +97,11 @@ sign, sin, sinh, + split, square, + squeeze, sqrt, - strided_slice, subtract, - take, tan, tanh, tril, @@ -472,12 +478,14 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "arg", "assert_op", "astype", + "broadcast_to", "builtin", "call_packed", "call_tir", "call_builtin_with_ctx", "ceil", "clip", + "concat", "cos", "cosh", "const", @@ -489,6 +497,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "equal", "ewise_fma", "exp", + "expand_dims", + "flatten", "floor", "floor_divide", "full", @@ -505,6 +515,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "isfinite", "isinf", "isnan", + "layout_transform", "less", "less_equal", "linear", @@ -522,6 +533,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "ones", "ones_like", "output", + "permute_dims", "prim_value", "print", "prod", @@ -537,7 +549,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "sign", "sin", "sinh", + "split", "square", + "squeeze", "sqrt", "str", "strided_slice", diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2088a8306e7a..8ce2a541da53 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -32,6 +32,310 @@ namespace tvm { namespace relax { +/* relax.broadcast_to */ +Expr broadcast_to(Expr x, Expr shape) { + static const Op& op = Op::Get("relax.broadcast_to"); + return Call(op, {std::move(x), std::move(shape)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to); + +StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& 
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 3a7ed427f9bf..efad5d98f01a 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -89,6 +89,36 @@ class StatisticalAttrs(Attrs):
     """Attributes used in statistical operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.ConcatAttrs")
+class ConcatAttrs(Attrs):
+    """Attributes for concat operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.ExpandDimsAttrs")
+class ExpandDimsAttrs(Attrs):
+    """Attributes for expand_dims operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.PermuteDimsAttrs")
+class PermuteDimsAttrs(Attrs):
+    """Attributes for permute_dims operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.SplitAttrs")
+class SplitAttrs(Attrs):
+    """Attributes used in split operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.SqueezeAttrs")
+class SqueezeAttrs(Attrs):
+    """Attributes for squeeze operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
+class LayoutTransformAttrs(Attrs):
+    """Attributes used in layout_transform operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
 class Resize2DAttrs(Attrs):
     """Attributes used in image resize2d operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index b779bdac9c13..da6010190004 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -39,17 +39,21 @@
     add,
     assert_op,
     astype,
+    broadcast_to,
     builtin,
     call_builtin_with_ctx,
     call_tir,
     ceil,
     clip,
+    concat,
     cos,
     cosh,
     divide,
     equal,
     ewise_fma,
     exp,
+    expand_dims,
+    flatten,
     floor,
     floor_divide,
     full,
@@ -61,6 +65,7 @@
     isfinite,
     isinf,
     isnan,
+    layout_transform,
     less,
     less_equal,
     linear,
@@ -75,6 +80,7 @@
     negative,
     not_equal,
     null_value,
+    permute_dims,
     ones,
     ones_like,
     print,
@@ -91,11 +97,11 @@
     sign,
     sin,
     sinh,
+    split,
     square,
+    squeeze,
     sqrt,
-    strided_slice,
     subtract,
-    take,
     tan,
     tanh,
     tril,
@@ -472,12 +478,14 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "arg",
     "assert_op",
     "astype",
+    "broadcast_to",
     "builtin",
     "call_packed",
     "call_tir",
     "call_builtin_with_ctx",
     "ceil",
     "clip",
+    "concat",
     "cos",
     "cosh",
     "const",
@@ -489,6 +497,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "equal",
     "ewise_fma",
     "exp",
+    "expand_dims",
+    "flatten",
     "floor",
     "floor_divide",
     "full",
@@ -505,6 +515,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "isfinite",
     "isinf",
     "isnan",
+    "layout_transform",
     "less",
     "less_equal",
     "linear",
@@ -522,6 +533,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "ones",
     "ones_like",
     "output",
+    "permute_dims",
     "prim_value",
     "print",
     "prod",
@@ -537,7 +549,9 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "sign",
     "sin",
     "sinh",
+    "split",
     "square",
+    "squeeze",
     "sqrt",
     "str",
     "strided_slice",
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 2088a8306e7a..8ce2a541da53 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -32,6 +32,310 @@
 namespace tvm {
 namespace relax {
 
+/* relax.broadcast_to */
+Expr broadcast_to(Expr x, Expr shape) {
+  static const Op& op = Op::Get("relax.broadcast_to");
+  return Call(op, {std::move(x), std::move(shape)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(broadcast_to);
+
+StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "broadcast_to should take 2 arguments.");
+  }
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* tgt_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "broadcast_to requires the input data to be Tensor. However, the given one is "
+        << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (tgt_shape_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "broadcast_to requires the input new shape to be Shape. However, the given one is "
+        << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  if (!data_sinfo->IsUnknownNdim() && !tgt_shape_sinfo->IsUnknownNdim() &&
+      tgt_shape_sinfo->ndim < data_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "broadcast_to expects the target shape to have an ndim at least as large "
+                        "as that of the input tensor. However, the given tensor has ndim "
+                     << data_sinfo->ndim << " while the target shape has ndim "
+                     << tgt_shape_sinfo->ndim);
+  }
+
+  // Trust the input target shape when there is no possibility to do any compile-time check.
+  if (!data_sinfo->shape.defined()) {
+    return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
+  }
+  ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
+  if (!shape_sinfo->values.defined() || !tgt_shape_sinfo->values.defined()) {
+    return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  Array<PrimExpr> old_shape_value = shape_sinfo->values.value();
+  Array<PrimExpr> tgt_shape_value = tgt_shape_sinfo->values.value();
+  int old_ndim = old_shape_value.size();
+  int tgt_ndim = tgt_shape_value.size();
+  for (int i = 0; i < old_ndim; ++i) {
+    PrimExpr old_len = old_shape_value[old_ndim - i - 1];
+    PrimExpr tgt_len = tgt_shape_value[tgt_ndim - i - 1];
+    const auto* old_len_int = old_len.as<IntImmNode>();
+    if (old_len_int != nullptr && old_len_int->value == 1) {
+      continue;
+    } else if (analyzer->CanProve(old_len != tgt_len)) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << "broadcast_to expects the input tensor shape to be broadcastable to the target "
+             "shape. The target shape at dim "
+          << tgt_ndim - i - 1 << " is " << tgt_len << " while the input tensor shape at dim "
+          << old_ndim - i - 1 << " is " << old_len << ", which are not equal.");
+    }
+    // Todo(relax-team): revisit here for better check on if the tensor length
+    // is consistent with the length in the given shape.
+  }
+  return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.broadcast_to")
+    .set_num_inputs(2)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("shape", "Shape", "The target shape.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBroadcastTo);
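+
+// For illustration: broadcasting x : Tensor((2, 1, 3)) to (5, 2, 4, 3) passes the
+// check above (the unit dimension is broadcast and a leading dimension is
+// prepended), while a target of (5, 4, 2, 3) is rejected, since the input
+// dimension 2 provably differs from the target dimension 4.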
+
+/* relax.concat */
+TVM_REGISTER_NODE_TYPE(ConcatAttrs);
+
+Expr concat(Expr tensors, Optional<Integer> axis) {
+  ObjectPtr<ConcatAttrs> attrs = make_object<ConcatAttrs>();
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.concat");
+  return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.concat").set_body_typed(concat);
+
+Array<TensorStructInfo> GetTensorSInfoFromTuple(const Call& call, const BlockBuilder& ctx,
+                                                const Expr& expr) {
+  const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(expr);
+  if (tuple_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << call->op
+                     << " expects the input to be a Tuple of Tensors. However, the given input is "
+                     << expr->struct_info_->GetTypeKey());
+  }
+
+  Array<TensorStructInfo> tensor_sinfo;
+  tensor_sinfo.reserve(tuple_sinfo->fields.size());
+  for (StructInfo field_sinfo : tuple_sinfo->fields) {
+    const auto* field_tensor_sinfo = field_sinfo.as<TensorStructInfoNode>();
+    if (field_tensor_sinfo == nullptr) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << call->op << " expects the input to be a Tuple of Tensors. However, the given input is "
+          << expr->struct_info_);
+    }
+    tensor_sinfo.push_back(GetRef<TensorStructInfo>(field_tensor_sinfo));
+  }
+  return tensor_sinfo;
+}
+
+Optional<Array<PrimExpr>> CheckConcatOutputShape(const Call& call, const BlockBuilder& ctx,
+                                                 const std::vector<Array<PrimExpr>>& shape_values,
+                                                 int axis) {
+  bool shape_unknown = false;
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr concat_sum = IntImm(DataType::Int(64), 0);
+  for (int d = 0; d < static_cast<int>(shape_values[0].size()); ++d) {
+    // For the specified axis, we compute the sum of shape value over each tensor.
+    if (d == axis) {
+      for (Array<PrimExpr> shape_value : shape_values) {
+        concat_sum += shape_value[d];
+      }
+      continue;
+    }
+
+    // For other axes, we check the equality of all tensors' shape values, to ensure safety.
+    for (int i = 1; i < static_cast<int>(shape_values.size()); ++i) {
+      if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "Concat expects the input tensors to have the same shape on every "
+                            "dimension except the one indicated by the input axis. However, the "
+                            "input contains tensors whose shapes on dimension "
+                         << d << " are " << shape_values[0][d] << " and " << shape_values[i][d]);
+      } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) {
+        shape_unknown = true;
+      }
+    }
+  }
+
+  if (shape_unknown) {
+    return NullOpt;
+  }
+  Array<PrimExpr> output_shape = shape_values[0];
+  output_shape.Set(axis, concat_sum);
+  return output_shape;
+}
+
+StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument");
+  }
+  Array<TensorStructInfo> tensor_sinfo = GetTensorSInfoFromTuple(call, ctx, call->args[0]);
+  if (tensor_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Concat op expects at least one tensor in the input Tuple. However, the "
+                        "given input Tuple is empty.");
+  }
+
+  const auto* attrs = call->attrs.as<ConcatAttrs>();
+  int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1;
+  DataType output_dtype = DataType::Void();
+  bool shape_unknown = false;
+  bool is_void_dtype = false;
+  std::vector<Array<PrimExpr>> shape_values;
+  shape_values.reserve(tensor_sinfo.size());
+
+  for (TensorStructInfo sinfo : tensor_sinfo) {
+    // Update the output dtype.
+    if (sinfo->dtype.is_void()) {
+      is_void_dtype = true;
+    } else if (output_dtype.is_void()) {
+      output_dtype = sinfo->dtype;
+    } else if (sinfo->dtype != output_dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Concat expects all input tensors to have the same dtype. However, the "
+                          "input contains tensors with dtype "
+                       << output_dtype << " and " << sinfo->dtype);
+    }
However, the " + "input contains tensors with ndim " + << output_ndim << " and " << sinfo->ndim); + } + + // Update the shape values for best effort check. + const auto* shape_expr = sinfo->shape.as(); + if (shape_expr != nullptr) { + shape_values.push_back(shape_expr->values); + continue; + } + shape_unknown = true; + + if (!sinfo->shape.defined()) { + continue; + } + // Keep the shape value for equality check. + ShapeStructInfo shape_sinfo = Downcast(sinfo->shape.value()->struct_info_); + if (shape_sinfo->values.defined()) { + shape_values.push_back(shape_sinfo->values.value()); + } + } + + if (is_void_dtype) { + output_dtype = DataType::Void(); + } + if (output_ndim == kUnknownNDim) { + return tensor_sinfo.size() == 1 ? tensor_sinfo[0] : TensorStructInfo(output_dtype, output_ndim); + } + + int axis = + attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0; + // If there is only one input tensor, no action is needed. + if (tensor_sinfo.size() == 1) { + return tensor_sinfo[0]; + } + if (shape_values.empty()) { + return TensorStructInfo(output_dtype, output_ndim); + } + + // As long as the there is known shape value, we will do the best effort check to ensure safety. + Optional> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis); + + if (shape_unknown || !output_shape.defined()) { + return TensorStructInfo(output_dtype, output_ndim); + } else { + return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); + } +} + +TVM_REGISTER_OP("relax.concat") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("tensors", "Tuple of Tensors", "The input list of tensors.") + .set_attr("FInferStructInfo", InferStructInfoConcat); + +/* relax.expand_dims */ +TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); + +Expr expand_dims(Expr x, Array axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.expand_dims"); + return Call(op, {std::move(x)}, Attrs{attrs}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.expand_dims").set_body_typed(expand_dims); + +StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axis.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int n_new_dim = attrs->axis.size(); + int output_ndim = data_sinfo->ndim + n_new_dim; + std::vector axes = NormalizeAxes(call, ctx, output_ndim, attrs->axis); + + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, output_ndim); + } + + std::vector output_shape; + output_shape.resize(output_ndim, PrimExpr()); + for (int i = 0; i < n_new_dim; ++i) { + output_shape[axes[i]] = IntImm(DataType::Int(64), 1); + } + + int i_data_shape = 0; + for (int i = 0; i < output_ndim; ++i) { + if (output_shape[i].defined()) { + continue; + } + ICHECK_LT(i_data_shape, data_sinfo->ndim); + output_shape[i] = data_shape->values[i_data_shape]; + ++i_data_shape; + } + ICHECK_EQ(i_data_shape, data_sinfo->ndim); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.expand_dims") + .set_num_inputs(1) + .set_attrs_type() + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoExpandDims); + // Helper function for flatten and reshape. 
+
 // Helper function for flatten and reshape.
 PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
   PrimExpr shape_prod = IntImm(DataType::Int(64), 1);
@@ -41,6 +345,172 @@ PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
   return shape_prod;
 }
 
+/* relax.flatten */
+Expr flatten(Expr x) {
+  static const Op& op = Op::Get("relax.flatten");
+  return Call(op, {std::move(x)}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.flatten").set_body_typed(flatten);
+
+StructInfo InferStructInfoFlatten(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
+  } else if (data_sinfo->ndim == 0) {
+    return TensorStructInfo(ShapeExpr({1}), data_sinfo->dtype);
+  } else if (data_sinfo->ndim == 1) {
+    return data_sinfo;
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/1);
+  }
+  PrimExpr shape_prod = ComputeShapeProduct(data_shape->values);
+  return TensorStructInfo(ShapeExpr({std::move(shape_prod)}), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.flatten")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);
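+
+// For example, flattening x : Tensor((2, 3, 4)) infers Tensor((24,)) via the
+// shape product above, while a 0-dim input infers Tensor((1,)).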
+
+/* relax.layout_transform */
+TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
+
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
+  ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
+  attrs->index_map = std::move(index_map);
+  attrs->pad_value = std::move(pad_value);
+
+  static const Op& op = Op::Get("relax.layout_transform");
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);
+
+StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
+  tir::IndexMap index_map = attrs->index_map;
+  Optional<PrimValue> optional_pad_value = attrs->pad_value;
+
+  // Check pad_value has same dtype as input.
+  if (optional_pad_value.defined()) {
+    PrimExpr padded_value = optional_pad_value.value()->value;
+    if (padded_value->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "layout_transform pad_value dtype (" << padded_value->dtype
+                       << ") and input dtype (" << data_sinfo->dtype << ") must be the same");
+    }
+  }
+
+  if (data_sinfo->IsUnknownNdim()) {
+    // Todo(relax-team): revisit here for better check on if the input tensor has desired ndim.
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  }
+
+  // If rank is known, check that it is compatible with the index_map, i.e., #dims match.
+  if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "number of dimensions in input must match the number of source dimensions "
+                        "in index map, but got "
+                     << data_sinfo->ndim << " != " << index_map->initial_indices.size());
+  }
+
+  if (!data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  }
+
+  ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
+  if (!shape_sinfo->values.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  }
+
+  Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value());
+  return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.layout_transform")
+    .set_num_inputs(1)
+    .set_attrs_type<LayoutTransformAttrs>()
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);
+
+/* relax.permute_dims */
+TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);
+
+Expr permute_dims(Expr x, Optional<Array<Integer>> axes) {
+  ObjectPtr<PermuteDimsAttrs> attrs = make_object<PermuteDimsAttrs>();
+  attrs->axes = std::move(axes);
+
+  static const Op& op = Op::Get("relax.permute_dims");
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.permute_dims").set_body_typed(permute_dims);
+
+bool IsIdentityPermutation(const std::vector<int>& permutation) {
+  for (int i = 0; i < static_cast<int>(permutation.size()); ++i) {
+    if (permutation[i] != i) {
+      return false;
+    }
+  }
+  return true;
+}
+
+StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
+
+  // Todo(relax-team): revisit here for better check on if the input tensor has
+  // ndim same as the number of input axes.
+  if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+
+  if (attrs->axes.defined()) {
+    int n_axis = attrs->axes.value().size();
+    if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "PermuteDims expects the number of input axes to equal the ndim of the "
+                          "input tensor. However, the tensor ndim is "
+                       << data_sinfo->ndim << " while the given number of axes is " << n_axis);
+    }
+  }
+
+  std::vector<int> axes;
+  if (attrs->axes.defined()) {
+    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes.value());
+  } else {
+    // Construct the reverse permutation via std::iota
+    axes.resize(data_sinfo->ndim);
+    std::iota(axes.rbegin(), axes.rend(), 0);
+  }
+  if (IsIdentityPermutation(axes)) {
+    return data_sinfo;
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+  }
+  std::vector<PrimExpr> new_shape;
+  new_shape.reserve(data_sinfo->ndim);
+  for (int i = 0; i < data_sinfo->ndim; ++i) {
+    new_shape.push_back(data_shape->values[axes[i]]);
+  }
+  return TensorStructInfo(ShapeExpr(new_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.permute_dims")
+    .set_attrs_type<PermuteDimsAttrs>()
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPermuteDims);
+
 /* relax.reshape */
 Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
   if (const auto* e = shape.as<ShapeExprNode>()) {
@@ -115,18 +585,18 @@ TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape);
 
 StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
   if (call->args.size() != 2) {
-    ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments");
+    ctx->ReportFatal(Diagnostic::Error(call) << "Reshape op should take 2 arguments");
   }
   const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
   const auto* new_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
   if (data_sinfo == nullptr) {
-    ctx->ReportFatal(Diagnostic::Error(call->span)
+    ctx->ReportFatal(Diagnostic::Error(call)
                      << "Reshape requires the input data to be Tensor. However, the given one is "
                      << call->args[0]->struct_info_->GetTypeKey());
   }
   if (new_shape_sinfo == nullptr) {
     ctx->ReportFatal(
-        Diagnostic::Error(call->span)
+        Diagnostic::Error(call)
         << "Reshape requires the input new shape to be Shape. However, the given one is "
         << call->args[1]->struct_info_->GetTypeKey());
   }
@@ -142,7 +612,7 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
     PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value());
     PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value());
     if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) {
-      ctx->ReportFatal(Diagnostic::Error(call->span)
+      ctx->ReportFatal(Diagnostic::Error(call)
                        << "Reshape expects the new shape to be convertible from the old shape. "
                           "However, the old shape is "
                        << data_sinfo->shape << ", with product " << old_shape_prod
@@ -159,5 +629,215 @@ TVM_REGISTER_OP("relax.reshape")
     .add_argument("shape", "Shape", "The input new shape.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape);
 
+/* relax.split */
+TVM_REGISTER_NODE_TYPE(SplitAttrs);
+
+Expr split(Expr x, ObjectRef indices_or_sections, int axis) {
+  ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>();
+  if (const auto* indices = indices_or_sections.as<ArrayNode>()) {
+    for (int i = 0; i < static_cast<int>(indices->size()); ++i) {
+      const auto* idx = indices->at(i).as<IntImmNode>();
+      CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. "
" + "However, the given indices " + << indices_or_sections << " contains some non-integer."; + } + indices_or_sections = ConvertIntImmToInt64(GetRef>(indices)); + } else if (const auto* n_section = indices_or_sections.as()) { + CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " + "positive integer. However, the given number of sections is " + << n_section->value; + indices_or_sections = IntImm(DataType::Int(64), n_section->value); + } else { + LOG(FATAL) << "Split op expects the input indices_or_sections to be either an Array of " + "PrimExpr or an integer. However, the given one is " + << indices_or_sections->GetTypeKey(); + } + attrs->indices_or_sections = indices_or_sections; + attrs->axis = axis; + + static const Op& op = Op::Get("relax.split"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(split); + +StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + const auto* data_shape = data_sinfo->shape.as(); + int axis = + data_sinfo->IsUnknownNdim() ? -1 : NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis); + + if (const auto* p_indices = attrs->indices_or_sections.as()) { + // When there is not index, return the input tensor's struct info. + if (p_indices->size() == 0) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. + if (data_shape == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + ICHECK_NE(axis, -1); + const auto* axis_length = data_shape->values[axis].as(); + // Fall back to unknown shape when the input tensor shape at the given axis is symbolic. + if (axis_length == nullptr) { + return TupleStructInfo(Array( + p_indices->size() + 1, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + + // Only do output shape inference when all the indices and the total length are integers. + Array indices = GetRef>(p_indices); + IntImm zero(DataType::Int(64), /*value=*/0); + indices.insert(indices.begin(), zero); + indices.insert(indices.end(), Downcast(data_shape->values[axis])); + + std::vector output_sinfo; + output_sinfo.reserve(indices.size() - 1); + for (int i = 0; i + 1 < static_cast(indices.size()); ++i) { + PrimExpr l = tvm::max(zero, indices[i]); + PrimExpr r = tvm::min(data_shape->values[axis], indices[i + 1]); + + Array shape = data_shape->values; + shape.Set(axis, tvm::max(zero, r - l)); + output_sinfo.push_back(TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype)); + } + return TupleStructInfo(output_sinfo); + } else if (const auto* p_n_section = attrs->indices_or_sections.as()) { + ICHECK_GT(p_n_section->value, 0); + int n_section = p_n_section->value; + // When the number of section is one, return the input tensor's struct info. + if (n_section == 1) { + return TupleStructInfo({data_sinfo}); + } + // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. + if (data_shape == nullptr) { + return TupleStructInfo( + Array(n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim))); + } + ICHECK_NE(axis, -1); + PrimExpr split_len = ceildiv(data_shape->values[axis], n_section); + + // Construct struct info for tensors except the last one. 
+
+/* relax.squeeze */
+TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
+
+Expr squeeze(Expr x, Optional<Array<Integer>> axis) {
+  ObjectPtr<SqueezeAttrs> attrs = make_object<SqueezeAttrs>();
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.squeeze");
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.squeeze").set_body_typed(squeeze);
+
+StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<SqueezeAttrs>();
+  if (attrs->axis.defined() && attrs->axis.value().empty()) {
+    return data_sinfo;
+  }
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+
+  Optional<Array<PrimExpr>> shape_value;
+  if (data_sinfo->shape.defined()) {
+    shape_value = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_)->values;
+  }
+
+  std::vector<bool> axis_removal_mask;
+  axis_removal_mask.resize(data_sinfo->ndim, /*value=*/false);
+
+  if (attrs->axis.defined()) {
+    std::vector<int> axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+
+    if (!shape_value.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size());
+    }
+    for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
+      // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1.
+      // When `axis` is given, the dim lengths at the axes must be integer 1 when not symbolic.
+      const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
+      if (int_len != nullptr && int_len->value != 1) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "Squeeze expects the input tensor shape values at the given axis "
+                            "positions to be all 1. However, the tensor shape at axis "
+                         << axes[i] << " is " << shape_value.value()[axes[i]]
+                         << " which is not 1. If it is symbolic, please use MatchCast to cast it "
+                            "to 1 before doing Squeeze.");
+      }
+      axis_removal_mask[axes[i]] = true;
+    }
+  } else {
+    // When `axis` is not defined, squeeze all unit-length dimensions.
+    // Note: This is a less well-defined path in the Array API standard's squeeze
+    // (https://data-apis.org/array-api/latest/API_specification/generated/array_api.squeeze.html).
+    // Consider discouraging its usage later.
+    if (!shape_value.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+    }
+    for (int i = 0; i < data_sinfo->ndim; ++i) {
+      // Whenever a dimension length is symbolic, fall back to unknown ndim.
+      const auto* int_len = shape_value.value()[i].as<IntImmNode>();
+      if (int_len == nullptr) {
+        return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+      }
+      if (int_len->value == 1) {
+        axis_removal_mask[i] = true;
+      }
+    }
+  }
+
+  std::vector<PrimExpr> output_shape;
+  output_shape.reserve(data_sinfo->ndim - axis_removal_mask.size());
+  for (int i = 0; i < data_sinfo->ndim; ++i) {
+    if (!axis_removal_mask[i]) {
+      output_shape.push_back(shape_value.value()[i]);
+    }
+  }
+
+  if (data_sinfo->shape.value()->IsInstance<ShapeExprNode>()) {
+    if (static_cast<int>(output_shape.size()) == data_sinfo->ndim) {
+      return data_sinfo;
+    } else if (attrs->axis.defined()) {
+      return TensorStructInfo(data_sinfo->dtype, output_shape.size());
+    } else {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+    }
+  } else {
+    return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+  }
+}
+
+TVM_REGISTER_OP("relax.squeeze")
+    .set_num_inputs(1)
+    .set_attrs_type<SqueezeAttrs>()
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSqueeze);
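+
+// Squeeze examples: for x : Tensor((1, 2, 1, 3)), squeeze(x) infers Tensor((2, 3)),
+// while squeeze(x, axis=[0]) infers Tensor((2, 1, 3)). For a symbolic shape such
+// as (a, 2), squeeze(x) with no axis falls back to unknown ndim, since whether a
+// equals 1 is not known at compile time.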
Last section will be smaller if the tensor + * size along the given dimension is not divisible by the integer. + * - If indices_or_sections is a tuple of mixture of int or PrimExpr, + * the entries indicate the indices where along axis the array is split. + * \param x The tensor to be split. + * \param indices_or_sections Indices or sections to split into. + * It is required to be an Array of PrimExpr or an integer. + * \param axis The axis over which to split. + * \return The computed result. + */ +Expr split(Expr x, ObjectRef indices_or_sections, int axis); + +/*! + * \brief Squeeze axes in the array. + * \param x The input data to the operator. + * \param axis The set of axes to remove. + * If it is `NullOpt`, remove all axis of dimensions 1. + * If any specified axis has dimension that does not equal 1, it is an error. + * \return The squeezed result. + */ +Expr squeeze(Expr x, Optional> axis); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py new file mode 100644 index 000000000000..92d4bb26760a --- /dev/null +++ b/tests/python/relax/test_op_manipulate.py @@ -0,0 +1,2373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + assert relax.op.broadcast_to(x, (3, 3, 4, 5)).op == Op.get("relax.broadcast_to") + assert relax.op.concat([x]).op == Op.get("relax.concat") + assert relax.op.expand_dims(x, axis=[]).op == Op.get("relax.expand_dims") + assert relax.op.flatten(x).op == Op.get("relax.flatten") + assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims") + assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape") + assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split") + assert relax.op.squeeze(x).op == Op.get("relax.squeeze") + assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get( + "relax.layout_transform" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_reshape_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4, 5))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + s0 = relax.Var("s", R.Shape((3, 8, 5))) + s1 = relax.Var("s", R.Shape(ndim=3)) + s2 = relax.Var("s", R.Shape()) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x0, (-1,)), relax.TensorStructInfo((120,), "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x3, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x4, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference( + bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") + ) + _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo(s0, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.reshape(x3, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s1), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s1), 
relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.reshape(x0, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x1, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x2, s2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.reshape(x3, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x4, s2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.reshape(x5, s2), relax.TensorStructInfo(s2, dtype="")) + + +def test_reshape_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + s0 = relax.Var("s", R.Shape((c, a, d, b))) + s1 = relax.Var("s", R.Shape()) + + _check_inference( + bb, relax.op.reshape(x, (c, a, d, b)), relax.TensorStructInfo((c, a, d, b), "float32") + ) + _check_inference( + bb, + relax.op.reshape(x, (d, c, b, -1)), + relax.TensorStructInfo((d, c, b, tir.floordiv(a * b * c * d, d * c * b)), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (1, -1, 1)), + relax.TensorStructInfo((1, a * b * c * d, 1), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (2, -1, a)), + relax.TensorStructInfo((2, tir.floordiv(a * b * c * d, a * 2), a), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, -1, d, b)), + relax.TensorStructInfo((c, tir.floordiv(a * b * c * d, c * d * b), d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * d, b)), + relax.TensorStructInfo((c, a * d, b), "float32"), + ) + _check_inference( + bb, + relax.op.reshape(x, (c, a * b * d, -1)), + relax.TensorStructInfo( + (c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), "float32" + ), + ) + _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) + + +def test_reshape_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + ns0 = relax.Var("ns", relax.ShapeStructInfo((3, 8, 5))) + ns1 = relax.Var("ns", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.reshape(x0, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference( + bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) + _check_inference( + bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") + ) + _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo(ns0, "float32")) + _check_inference(bb, relax.op.reshape(x2, ns1), 
relax.TensorStructInfo(ns1, "float32")) + + +def test_reshape_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8")) + + _check_inference(bb, relax.op.reshape(x0, (120,)), relax.TensorStructInfo((120,), "float16")) + _check_inference(bb, relax.op.reshape(x1, (120,)), relax.TensorStructInfo((120,), "int8")) + + +def test_reshape_infer_struct_info_unequal_shape_prod(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((2, 3, 4, 5))) + x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + ns = relax.Var("ns", relax.ShapeStructInfo((4, 4, 1, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, 1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (4, 4, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, ns)) + + +def test_reshape_infer_struct_info_inference_not_deducible(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32", ndim=4)) + x1 = relax.Var("x", R.Tensor("float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, (2, 3, -1))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x3, (2, 3, -1))) + + +def test_reshape_new_shape_not_tuple(): + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + relax.op.reshape(x, 120) + with pytest.raises(TVMError): + relax.op.reshape(x, m) + + +def test_reshape_infer_struct_info_new_shape_not_integer(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2.0, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, -1.0))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 3, 4.0, -1))) + + +def test_reshape_infer_struct_info_multiple_dim_inference(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, -1, -1, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-1, -1, -1, -1))) + + +def test_reshape_infer_struct_info_non_positive_new_shape(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (2, 0, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5))) + + +def test_reshape_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + 
ns = relax.Var("ns", relax.TensorStructInfo((120,), "float32")) + pv = relax.Var("pv", relax.PrimStructInfo("int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x0, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x1, (2, 3, 4, 5))) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, ns)) + with pytest.raises(TVMError): + bb.normalize(relax.op.reshape(x2, [pv])) + + +def test_permute_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((1, 2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + x6 = relax.Var("x", R.Tensor((1,), "float32")) + x7 = relax.Var("x", R.Tensor((), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo((4, 3, 2, 1), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x0, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), "float32"), + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + _check_inference( + bb, relax.op.permute_dims(x3, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), dtype="") + ) + _check_inference( + bb, relax.op.permute_dims(x3, axes=None), relax.TensorStructInfo((4, 3, 2, 1), dtype="") + ) + _check_inference( + bb, + relax.op.permute_dims(x3, [-2, -3, 3, -4]), + relax.TensorStructInfo((3, 2, 4, 1), dtype=""), + ) + _check_inference( + bb, relax.op.permute_dims(x4, [2, 3, 1, 0]), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x4, axes=None), relax.TensorStructInfo(dtype="", ndim=4) + ) + _check_inference(bb, relax.op.permute_dims(x5, axes=None), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.permute_dims(x6, axes=None), relax.TensorStructInfo((1,), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x7, axes=None), relax.TensorStructInfo((), "float32") + ) + + +def test_permute_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) + + _check_inference( + bb, relax.op.permute_dims(x, [2, 3, 1, 0]), relax.TensorStructInfo((c, d, b, a), "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x, axes=None), relax.TensorStructInfo((d, c, b, a), "float32") + ) + _check_inference( + bb, + relax.op.permute_dims(x, [-2, -3, 3, -4]), + relax.TensorStructInfo((c, b, d, a), "float32"), + ) + + +def test_permute_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, 
relax.op.permute_dims(x0, [0, 1, 2, 3]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [-4, -3, -2, -1]), relax.TensorStructInfo(s0, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x0, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, [0, 1, 2, 3]), relax.TensorStructInfo(s1, "float32") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 0, 1]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x1, axes=None), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference( + bb, relax.op.permute_dims(x2, axes=None), relax.TensorStructInfo(dtype="float32") + ) + + +def test_permute_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((1, 2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.permute_dims(x0, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "float16") + ) + _check_inference( + bb, relax.op.permute_dims(x1, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int8") + ) + _check_inference( + bb, relax.op.permute_dims(x2, [2, 3, 1, 0]), relax.TensorStructInfo((3, 4, 2, 1), "int32") + ) + + +def test_permute_dims_infer_struct_info_unknown_ndim_with_axes(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor("float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [2, 3, 1, 0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [2, 3, 1, 0])) + + +def test_permute_dims_infer_struct_info_wrong_number_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1, 2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x2, [1, 2, 4, 0, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [0, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x3, [1, 2, 4, 0, 3])) + + +def test_permute_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 3, 4, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, -5, 1, 3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 3, 4, 1])) + with 
pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, -5, 1, 3])) + + +def test_permute_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0, [0, 2, -2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, 2, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1, [0, 2, -2, 1])) + + +def test_permute_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((1, 2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((1, 2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.permute_dims(x1)) + + +def test_expand_dims_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x0, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference( + bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.expand_dims(x3, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), dtype="") + ) + _check_inference( + bb, + relax.op.expand_dims(x3, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((2, 1, 1, 1, 3, 1, 4, 1), dtype=""), + ) + _check_inference(bb, relax.op.expand_dims(x3, []), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.expand_dims(x4, [1, 3]), relax.TensorStructInfo(dtype="", ndim=5)) + _check_inference(bb, relax.op.expand_dims(x4, []), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.expand_dims(x5, [1, 3]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.expand_dims(x5, []), relax.TensorStructInfo(dtype="")) + + +def test_expand_dims_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, 4, b), "float32")) + + _check_inference( + bb, relax.op.expand_dims(x, [1, 3]), relax.TensorStructInfo((a, 1, 4, 1, b), "float32") + ) + _check_inference( + bb, + relax.op.expand_dims(x, [-1, 1, -6, 3, 5]), + relax.TensorStructInfo((a, 1, 1, 1, 4, 1, b, 1), "float32"), + ) + _check_inference(bb, relax.op.expand_dims(x, []), relax.TensorStructInfo((a, 4, b), "float32")) + + +def test_expand_dims_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = 
relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo(dtype="float32", ndim=5) + ) + _check_inference(bb, relax.op.expand_dims(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.expand_dims(x2, []), relax.TensorStructInfo(s2, "float32")) + + +def test_expand_dims_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference( + bb, relax.op.expand_dims(x0, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "float16") + ) + _check_inference( + bb, relax.op.expand_dims(x1, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int8") + ) + _check_inference( + bb, relax.op.expand_dims(x2, [1, 3]), relax.TensorStructInfo((2, 1, 3, 1, 4), "int32") + ) + + +def test_expand_dims_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [-6, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 5])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [-6, 1])) + + +def test_expand_dims_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", relax.TensorStructInfo(s0)) + x3 = relax.Var("x", relax.TensorStructInfo(s1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x2, [1, -4])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x3, [1, 1])) + with pytest.raises(TVMError): + 
bb.normalize(relax.op.expand_dims(x3, [1, -4])) + + +def test_expand_dims_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x0, axis=[])) + with pytest.raises(TVMError): + bb.normalize(relax.op.expand_dims(x1, axis=[])) + + +def test_layout_transform_infer_struct_info(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + transpose_transform = lambda a, b, c: (a, c, b) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=transpose_transform), + relax.TensorStructInfo((10, 30, 20), "float32"), + ) + + tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_transform), + relax.TensorStructInfo((10, 10, 30, 2), "float32"), + ) + + implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2), + relax.TensorStructInfo((10, 30, 7, 3), "float32"), + ) + + flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=flatten_transform), + relax.TensorStructInfo((6000,), "float32"), + ) + + +def test_layout_transform_infer_struct_info_mismatch_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "int32")) + + transpose_transform = lambda a, b, c: (a, c, b) + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2)) + + +def test_layout_transform_infer_struct_info_unknown_shape(): + bb = relax.BlockBuilder() + tiling_transform = lambda a, b: (a, b // 2, b % 2) + + x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2)) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + x_unknown_rank_dtype = relax.Var("x", R.Tensor()) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform), + relax.TensorStructInfo(dtype="", ndim=3), + ) + + +def test_layout_transform_infer_struct_info_symbolic_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + + tiling_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x0, index_map=tiling_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + + s = relax.Var("s", relax.ShapeStructInfo((30, 20))) + x = relax.Var("x", relax.TensorStructInfo(s, "float32")) + tiling_padding_transform = lambda a, b: (a, b // 3, b % 3) + _check_inference( + bb, + relax.op.layout_transform(x, index_map=tiling_padding_transform), + relax.TensorStructInfo((30, 7, 3), "float32"), + ) + + s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + s_unknown_rank = relax.Var("s", relax.ShapeStructInfo()) + x_unknown_rank = relax.Var("x", 
relax.TensorStructInfo(s_unknown_rank, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) + x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) + _check_inference( + bb, + relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform), + relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"), + ) + + +def test_layout_transform_infer_struct_info_invalid_index_map(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((10, 20, 30), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a))) + + +def test_squeeze_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4))) + x4 = relax.Var("x", R.Tensor(ndim=6)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), "float32") + ) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference( + bb, relax.op.squeeze(x1, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo((2, 3, 1, 4), dtype="") + ) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="", ndim=4)) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5, [1, 4]), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.squeeze(x5), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, 1, b))) + + _check_inference(bb, relax.op.squeeze(x0, [1]), relax.TensorStructInfo((a, b), "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, [1]), relax.TensorStructInfo((a, b), dtype="")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(dtype="")) + + +def test_squeeze_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", 
relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference( + bb, relax.op.squeeze(x0, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x0, []), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x1, []), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference(bb, relax.op.squeeze(x2, [1]), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.squeeze(x2, []), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.squeeze(x3, [1, 4]), relax.TensorStructInfo(dtype="float32", ndim=4) + ) + _check_inference(bb, relax.op.squeeze(x3, []), relax.TensorStructInfo(s3, "float32")) + _check_inference(bb, relax.op.squeeze(x3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, [1, 4]), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.squeeze(x4, []), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.squeeze(x4), relax.TensorStructInfo(dtype="float32")) + + +def test_squeeze_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "int32")) + + _check_inference(bb, relax.op.squeeze(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.squeeze(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.squeeze(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + + +def test_squeeze_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [-7])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [6])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [-7])) + + +def test_squeeze_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=6)) + x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=6)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [3, -3])) + with pytest.raises(TVMError): + 
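+ # With 6-D inputs, -3 normalizes to axis 3, so [3, -3] duplicates an axis just as the literal [1, 1] below does.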
bb.normalize(relax.op.squeeze(x0, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [1, 1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [3, -3])) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x3, [1, 1])) + + +def test_squeeze_infer_struct_info_axis_length_not_one(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((a, 3, 4), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0, [0])) + _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x2, [0])) + _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2)) + + +def test_squeeze_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.squeeze(x1)) + + +def test_flatten_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", R.Tensor((3,), "float32")) + x2 = relax.Var("x", R.Tensor((), "float32")) + x3 = relax.Var("x", R.Tensor("float32", ndim=3)) + x4 = relax.Var("x", R.Tensor("float32", ndim=1)) + x5 = relax.Var("x", R.Tensor("float32", ndim=0)) + x6 = relax.Var("x", R.Tensor("float32")) + x7 = relax.Var("x", R.Tensor((3, 4, 5))) + x8 = relax.Var("x", R.Tensor((3,))) + x9 = relax.Var("x", R.Tensor(())) + x10 = relax.Var("x", R.Tensor(ndim=3)) + x11 = relax.Var("x", R.Tensor(ndim=1)) + x12 = relax.Var("x", R.Tensor(ndim=0)) + x13 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((3,), "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x7), relax.TensorStructInfo((60,), dtype="")) + _check_inference(bb, relax.op.flatten(x8), relax.TensorStructInfo((3,), dtype="")) + _check_inference(bb, relax.op.flatten(x9), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, relax.op.flatten(x10), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x11), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.flatten(x12), relax.TensorStructInfo((1,), dtype="")) + _check_inference(bb, 
relax.op.flatten(x13), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_flatten_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor((a, b))) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((a * b,), "float32")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((a * b,), dtype="")) + + +def test_flatten_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo((3,))) + s2 = relax.Var("s", relax.ShapeStructInfo(())) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + s6 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s6, "float32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.flatten(x4), relax.TensorStructInfo(s4, "float32")) + _check_inference(bb, relax.op.flatten(x5), relax.TensorStructInfo((1,), "float32")) + _check_inference(bb, relax.op.flatten(x6), relax.TensorStructInfo(dtype="float32", ndim=1)) + + +def test_flatten_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4, 5), "int32")) + + _check_inference(bb, relax.op.flatten(x0), relax.TensorStructInfo((60,), "float16")) + _check_inference(bb, relax.op.flatten(x1), relax.TensorStructInfo((60,), "int8")) + _check_inference(bb, relax.op.flatten(x2), relax.TensorStructInfo((60,), "int32")) + + +def test_flatten_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((3, 4, 5))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((3, 4, 5), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.flatten(x1)) + + +def test_flatten_wrong_input_number(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + + with pytest.raises(TypeError): + relax.op.flatten(x, y) + + +def test_concat_infer_struct_info_with_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=3)) + y2 = relax.Var("y", R.Tensor("float32")) + y3 = 
relax.Var("y", R.Tensor((2, 4, 4))) + y4 = relax.Var("y", R.Tensor(ndim=3)) + y5 = relax.Var("y", R.Tensor()) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=3)) + z2 = relax.Var("z", R.Tensor("float32")) + z3 = relax.Var("z", R.Tensor((2, 5, 4))) + z4 = relax.Var("z", R.Tensor(ndim=3)) + z5 = relax.Var("z", R.Tensor()) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), "float32") + ) + _check_inference( + bb, relax.op.concat([x1, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y0, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y1, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z0], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y1, z1], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y2, z2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y2, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z2], axis=1), relax.TensorStructInfo(dtype="", ndim=-1) + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=1), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y3, z3], axis=-2), relax.TensorStructInfo((2, 12, 4), dtype="") + ) + _check_inference( + bb, relax.op.concat([x4, y3, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z3], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y4, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x5, y5, z4], axis=1), relax.TensorStructInfo(dtype="", ndim=3) + ) + 
_check_inference(bb, relax.op.concat([x5, y5, z5], axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo((2, 12, 4), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) + x1 = relax.Var("x", R.Tensor((a1, b0, c), "float32")) + y = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, + relax.op.concat([x0, y, z], axis=-2), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + _check_inference( + bb, relax.op.concat([x1, y, z], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y, z]), axis=1), + relax.TensorStructInfo((a0, b0 + b1 + b2, c), "float32"), + ) + + +def test_concat_infer_struct_info_with_axis_shape_var(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + b2 = tir.Var("b2", "int64") + c = tir.Var("c", "int64") + sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) + sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) + sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) + sx3 = relax.Var("sx", relax.ShapeStructInfo(ndim=3)) + sx4 = relax.Var("sx", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(sx3, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(sx4, "float32")) + y0 = relax.Var("y", R.Tensor((2, 4, 4), "float32")) + y1 = relax.Var("y", R.Tensor((a0, b1, c), "float32")) + z0 = relax.Var("z", R.Tensor((2, 5, 4), "float32")) + z1 = relax.Var("z", R.Tensor((a0, b2, c), "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x1, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x2, y1, z1], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x4, y0, z0], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=1), + relax.TensorStructInfo(dtype="float32", ndim=3), + ) + + +def test_concat_infer_struct_info_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", R.Tensor((3,))) + x3 = relax.Var("x", R.Tensor(ndim=1)) + y0 = relax.Var("y", R.Tensor((4,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + z0 = relax.Var("z", R.Tensor((5,), "float32")) + z1 = relax.Var("z", R.Tensor("float32", ndim=1)) + + _check_inference( + bb, relax.op.concat([x0, 
y0, z0], axis=None), relax.TensorStructInfo((12,), "float32") + ) + _check_inference( + bb, + relax.op.concat([x1, y0, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y0, z0], axis=None), relax.TensorStructInfo((12,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x3, y0, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z0], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, relax.op.concat([x2, y1, z0], axis=None), relax.TensorStructInfo(dtype="", ndim=1) + ) + _check_inference( + bb, + relax.op.concat([x1, y1, z1], axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0, z0]), axis=None), + relax.TensorStructInfo((12,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_symbolic(): + bb = relax.BlockBuilder() + a0 = tir.Var("a0", "int64") + a1 = tir.Var("a1", "int64") + x0 = relax.Var("x", R.Tensor((a0,), "float32")) + x1 = relax.Var("x", R.Tensor((a0,), "")) + y0 = relax.Var("y", R.Tensor((a1,), "float32")) + y1 = relax.Var("y", R.Tensor((a1,), "")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((a0 + a1,), "float32") + ) + _check_inference( + bb, relax.op.concat([x0, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((a0 + a1,), dtype="") + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo((a0 + a1,), "float32"), + ) + + +def test_concat_infer_struct_info_without_axis_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", relax.ShapeStructInfo((3,))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=1)) + sy0 = relax.Var("sy", relax.ShapeStructInfo((4,))) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x1, y0], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, + relax.op.concat(relax.Tuple([x0, y0]), axis=None), + relax.TensorStructInfo(dtype="float32", ndim=1), + ) + + +def test_concat_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3,), "float16")) + y0 = relax.Var("y", R.Tensor((4,), "float16")) + x1 = relax.Var("x", R.Tensor((3,), "int8")) + y1 = relax.Var("y", R.Tensor((4,), "int8")) + x2 = relax.Var("x", R.Tensor((3,), "int32")) + y2 = relax.Var("y", R.Tensor((4,), "int32")) + + _check_inference( + bb, relax.op.concat([x0, y0], axis=None), relax.TensorStructInfo((7,), "float16") + ) + _check_inference(bb, relax.op.concat([x1, y1], axis=None), relax.TensorStructInfo((7,), "int8")) + _check_inference( + bb, relax.op.concat([x2, y2], axis=None), relax.TensorStructInfo((7,), "int32") + ) + + +def test_concat_infer_struct_info_tuple_var(): + bb = relax.BlockBuilder() + a = tir.Var("a0", "int64") + b0 = tir.Var("b0", "int64") + b1 = tir.Var("b1", "int64") + t0 = relax.Var( + "t", + relax.TupleStructInfo( 
+ [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1), "float32")] + ), + ) + t1 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, b0), "float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t2 = relax.Var( + "t", + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + t3 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + t4 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), "float32"), relax.TensorStructInfo((a, b1))] + ), + ) + t5 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo((a, b0), dtype=""), relax.TensorStructInfo((a, b1), dtype="")] + ), + ) + t6 = relax.Var( + "t", + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="", ndim=2), relax.TensorStructInfo(dtype="")] + ), + ) + t7 = relax.Var( + "t", + relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), relax.TensorStructInfo(dtype="")]), + ) + + _check_inference( + bb, relax.op.concat(t0, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.concat(t2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.concat(t3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.concat(t4, axis=1), relax.TensorStructInfo((a, b0 + b1), "float32") + ) + _check_inference( + bb, relax.op.concat(t5, axis=1), relax.TensorStructInfo((a, b0 + b1), dtype="") + ) + _check_inference(bb, relax.op.concat(t6, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.concat(t7, axis=1), relax.TensorStructInfo(dtype="")) + + +def test_concat_infer_struct_info_single_input_tensor(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) + s1 = relax.Var("s", relax.ShapeStructInfo((a,))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a), "float32")) + x1 = relax.Var("x", R.Tensor((a,), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=3)) + x3 = relax.Var("x", R.Tensor("float32", ndim=1)) + x4 = relax.Var("x", R.Tensor("float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x8 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + x9 = relax.Var("x", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, relax.op.concat([x0], axis=1), relax.TensorStructInfo((3, a), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=0), relax.TensorStructInfo((a,), "float32")) + _check_inference(bb, relax.op.concat([x1], axis=None), relax.TensorStructInfo((a,), "float32")) + _check_inference( + bb, relax.op.concat([x2], axis=1), relax.TensorStructInfo(dtype="float32", ndim=3) + ) + _check_inference( + bb, relax.op.concat([x3], axis=0), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference( + bb, relax.op.concat([x3], axis=None), relax.TensorStructInfo(dtype="float32", ndim=1) + ) + _check_inference(bb, relax.op.concat([x4], axis=1), 
relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.concat([x5], axis=1), relax.TensorStructInfo(s0, dtype="float32")) + _check_inference(bb, relax.op.concat([x6], axis=0), relax.TensorStructInfo(s1, dtype="float32")) + _check_inference( + bb, relax.op.concat([x6], axis=None), relax.TensorStructInfo(s1, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x7], axis=1), relax.TensorStructInfo(s2, dtype="float32")) + _check_inference(bb, relax.op.concat([x8], axis=0), relax.TensorStructInfo(s3, dtype="float32")) + _check_inference( + bb, relax.op.concat([x8], axis=None), relax.TensorStructInfo(s3, dtype="float32") + ) + _check_inference(bb, relax.op.concat([x9], axis=1), relax.TensorStructInfo(s4, dtype="float32")) + + +def test_concat_infer_struct_info_zero_rank_input_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=0)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + + +def test_concat_infer_struct_info_no_input_tensor(): + bb = relax.BlockBuilder() + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([], axis=None)) + + +def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=None)) + _check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=None)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x4], axis=None)) + _check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32")) + + +def test_concat_infer_struct_info_inconsistent_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,))) + y = relax.Var("y", R.Tensor((4,), "float32")) + z = relax.Var("z", R.Tensor((5,), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y, z], axis=0)) + + +def test_concat_infer_struct_info_inconsistent_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((4, 5))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + x = relax.Var("x", R.Tensor((3,), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=2)) + y2 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + 
y3 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + z = relax.Var("z", R.Tensor((5,), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y0, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y1, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y2, z], axis=0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x, y3, z], axis=0)) + + +def test_concat_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((3,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", R.Tensor((3,), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=1)) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2], axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3], axis=1)) + + +def test_concat_infer_struct_info_unequal_shape(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, a + 2), "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + y0 = relax.Var("y", R.Tensor((3, 3), "float32")) + y1 = relax.Var("y", R.Tensor((3, a), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x0, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x2, y0])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x1, y1])) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([x3, y1])) + + +def test_concat_infer_struct_info_input_not_tuple(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((3,), "float32")) + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(x)) + with pytest.raises(TVMError): + bb.normalize(relax.op.concat(s)) + + +def test_concat_infer_struct_info_input_tuple_field_not_tensor(): + bb = relax.BlockBuilder() + s = relax.Var("s", relax.ShapeStructInfo((3,))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.concat([s])) + + +def test_split_infer_struct_info_by_indices(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 3, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", 
ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 3, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, [-2, 2, 6, 4, 8, 12, 9], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + relax.TensorStructInfo((2, 0, 4), "float32"), + relax.TensorStructInfo((2, 1, 4), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, [10, 20], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorStructInfo(dtype="float32", ndim=2), + relax.TensorStructInfo(dtype="float32", ndim=2), + ] + ), + ) + + +def test_split_infer_struct_info_by_indices_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, [3], axis=1), + relax.TupleStructInfo( + [relax.TensorStructInfo(dtype="float32"), relax.TensorStructInfo(dtype="float32")] + ), + ) + + +def test_split_infer_struct_info_by_n_section(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 10, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + 
relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 2, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 5, 4), "float32"), + relax.TensorStructInfo((2, 5, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=-2), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 4, 4), "float32"), + relax.TensorStructInfo((2, 2, 4), "float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x3, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 4, 4), dtype=""), + relax.TensorStructInfo((2, 2, 4), dtype=""), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x4, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + relax.TensorStructInfo(dtype="", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x5, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + relax.TensorStructInfo(dtype=""), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x = relax.Var("x", R.Tensor((a, b), "float32")) + + _check_inference( + bb, + relax.op.split(x, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, tir.ceildiv(b, 3)), "float32"), + relax.TensorStructInfo((a, b - tir.ceildiv(b, 3) * 2), "float32"), + ] + ), + ) + + +def test_split_infer_struct_info_by_n_section_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 10, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + relax.TensorStructInfo(dtype="float32", ndim=3), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x2, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + relax.TensorStructInfo(dtype="float32"), + ] + ), + ) + + +def test_split_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 10, 4), "float16")) + x1 = relax.Var("x", 
R.Tensor((2, 10, 4), "int8")) + + _check_inference( + bb, + relax.op.split(x0, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 3, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, [3, 7], axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 3, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 3, 4), "int8"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x0, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 4, 4), "float16"), + relax.TensorStructInfo((2, 2, 4), "float16"), + ] + ), + ) + _check_inference( + bb, + relax.op.split(x1, 3, axis=1), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 4, 4), "int8"), + relax.TensorStructInfo((2, 2, 4), "int8"), + ] + ), + ) + + +def test_split_infer_struct_info_single_output(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.split(x0, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, [], axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x0, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + ) + _check_inference( + bb, + relax.op.split(x1, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + ) + _check_inference( + bb, + relax.op.split(x2, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + ) + _check_inference( + bb, + relax.op.split(x3, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x4, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + ) + _check_inference( + bb, + relax.op.split(x5, 1, axis=1), + relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + ) + + +def test_split_indices_or_sections_int64(): + x = relax.Var("x", R.Tensor((2, 10, 4), "float32")) + split0 = relax.op.split(x, [3, 6], axis=1) + split1 = relax.op.split(x, 4, axis=1) + + assert split0.attrs.indices_or_sections[0].dtype == "int64" + assert split0.attrs.indices_or_sections[1].dtype == "int64" + assert 
+
+
+def test_split_indices_or_sections_int64():
+    x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    split0 = relax.op.split(x, [3, 6], axis=1)
+    split1 = relax.op.split(x, 4, axis=1)
+
+    assert split0.attrs.indices_or_sections[0].dtype == "int64"
+    assert split0.attrs.indices_or_sections[1].dtype == "int64"
+    assert split1.attrs.indices_or_sections.dtype == "int64"
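+
+
+# The assertions above indicate that Python ints passed as indices_or_sections are
+# canonicalized to 64-bit TIR integers, so consumers of SplitAttrs can rely on a
+# uniform "int64" dtype however the call was spelled.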
+
+
+def test_split_infer_struct_info_non_integer_indices():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x = relax.Var("x", R.Tensor((3, 4), "float32"))
+
+    # Symbolic (non-constant) indices cannot be resolved during struct info inference.
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x, [a, b], axis=1))
+
+
+def test_split_invalid_n_section():
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((3, 4), "float32"))
+
+    with pytest.raises(TVMError):
+        relax.op.split(x, 0, axis=1)
+    with pytest.raises(TVMError):
+        relax.op.split(x, -1, axis=1)
+    with pytest.raises(TVMError):
+        relax.op.split(x, n, axis=1)
+
+
+def test_split_infer_struct_info_axis_out_of_range():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x0, [], axis=2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x0, [], axis=-3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x1, 1, axis=2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x1, 1, axis=-3))
+
+
+def test_split_infer_invalid_struct_info_indices():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    v = relax.Var("v", relax.PrimStructInfo("int64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x0, [v], axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x0, v, axis=1))
+
+
+def test_split_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x0, 1, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.split(x1, 1, axis=1))
+
+
+def test_broadcast_to_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 1, 3)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x3, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x4, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x5, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), dtype="")
+    )
+
+
+def test_broadcast_to_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d = tir.Var("d", "int64")
+    x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32"))
+    x1 = relax.Var("x", R.Tensor((b, 1, 1, d)))
+
+    _check_inference(
+        bb,
+        relax.op.broadcast_to(x0, (a, b, 1, c, d)),
+        relax.TensorStructInfo((a, b, 1, c, d), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.broadcast_to(x1, (a, b, 1, c, d)),
+        relax.TensorStructInfo((a, b, 1, c, d), dtype=""),
+    )
+
+
+def test_broadcast_to_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float32")
+    )
+
+
+def test_broadcast_to_infer_struct_info_tgt_shape_var():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d = tir.Var("d", "int64")
+    s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b, 1, c, d)))
+    stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=5))
+    stgt2 = relax.Var("stgt", relax.ShapeStructInfo())
+
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x2, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x2, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x2, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32"))
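+
+
+# When the target shape is only visible as a shape var (stgt0/stgt1/stgt2 above), the
+# inferred struct info simply reuses that var as the result shape; there is nothing
+# more precise to derive at compile time, regardless of the input's own shape info.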
+
+
+def test_broadcast_to_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 1, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 1, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 1, 3), "int32"))
+
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "float16")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int8")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x2, (4, 2, 5, 3)), relax.TensorStructInfo((4, 2, 5, 3), "int32")
+    )
+
+
+def test_broadcast_to_infer_struct_info_tgt_ndim_less_than_old_ndim():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 1)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    x0 = relax.Var("x", R.Tensor((2, 1), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+    x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2,)))
+    stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=1))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, (2,)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, stgt0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, stgt1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, (2,)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, stgt0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, stgt1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x2, (2,)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x2, stgt0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x2, stgt1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x3, (2,)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x3, stgt0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x3, stgt1))
+
+
+def test_broadcast_to_infer_struct_info_not_broadcastable_static():
+    bb = relax.BlockBuilder()
+    s = relax.Var("s", relax.ShapeStructInfo((2, 1, 3)))
+    x0 = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s, "float32"))
+    stgt = relax.Var("stgt", relax.ShapeStructInfo((2, 1, 6)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, (2, 1, 6)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, stgt))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, (2, 1, 6)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, stgt))
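+
+
+# The static cases above (3 vs. 6 in the last axis) are rejected outright, while the
+# symbolic cases in the next test are accepted: a target such as (2, b) or even (2, 1)
+# for an input of shape (2, a) cannot be disproved at compile time, so the check is
+# effectively deferred to runtime.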
+
+
+def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    s = relax.Var("s", relax.ShapeStructInfo((2, a)))
+    x0 = relax.Var("x", R.Tensor((2, a), "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s, "float32"))
+    stgt0 = relax.Var("stgt", relax.ShapeStructInfo((2, b)))
+    stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1)))
+    stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a)))
+
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32")
+    )
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32")
+    )
+    _check_inference(
+        bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32")
+    )
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32"))
+    _check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32"))
+
+
+def test_broadcast_to_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 1, 3)))
+    x1 = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
+    stgt = relax.Var("stgt", relax.TensorStructInfo((4, 2, 5, 3), dtype=""))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x0, (4, 2, 5, 3)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.broadcast_to(x1, stgt))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
new file mode 100644
index 000000000000..27f089ee67c1
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -0,0 +1,314 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
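+
+"""Parser tests for Relax manipulation operators: each case builds the same function
+twice, via TVMScript and via BlockBuilder, checks that the script() output re-parses
+to a structurally equal module, and compares the two constructions."""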
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    test = parsed.script(show_meta=True)
+    roundtrip_mod = tvm.script.from_source(test)
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_broadcast_to():
+    @R.function
+    def foo(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32"):
+        gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3))
+        return gv
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 1, 3), "float32"))
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.broadcast_to(x, (4, 2, 5, 3)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_concat():
+    @R.function
+    def foo(
+        x1: R.Tensor((1, 2, 3), "float32"),
+        x2: R.Tensor((1, 3, 3), "float32"),
+        x3: R.Tensor((1, 4, 3), "float32"),
+    ) -> R.Tensor((1, 9, 3), "float32"):
+        gv: R.Tensor((1, 9, 3), "float32") = R.concat((x1, x2, x3), axis=1)
+        return gv
+
+    x1 = relax.Var("x1", R.Tensor((1, 2, 3), "float32"))
+    x2 = relax.Var("x2", R.Tensor((1, 3, 3), "float32"))
+    x3 = relax.Var("x3", R.Tensor((1, 4, 3), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x1, x2, x3]):
+        gv = bb.emit(relax.op.concat((x1, x2, x3), axis=1))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_concat_without_specified_axis():
+    @R.function
+    def foo(
+        x1: R.Tensor((2,), "float32"), x2: R.Tensor((3,), "float32"), x3: R.Tensor((4,), "float32")
+    ) -> R.Tensor((9,), "float32"):
+        gv: R.Tensor((9,), "float32") = R.concat((x1, x2, x3), axis=None)
+        return gv
+
+    x1 = relax.Var("x1", R.Tensor((2,), "float32"))
+    x2 = relax.Var("x2", R.Tensor((3,), "float32"))
+    x3 = relax.Var("x3", R.Tensor((4,), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x1, x2, x3]):
+        gv = bb.emit(relax.op.concat((x1, x2, x3), axis=None))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_expand_dims():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32"):
+        gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims(x, axis=[-1, 1, -6, 3, 5])
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.expand_dims(x, axis=[-1, 1, -6, 3, 5]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_flatten():
+    @R.function
+    def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((60,), "float32"):
+        gv: R.Tensor((60,), "float32") = R.flatten(x)
+        return gv
+
+    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.flatten(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_layout_transform():
+    transformation = lambda n, c, h, w: (n, h, w, c)
+
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4, 5), "float32")):
+        gv: R.Tensor((2, 4, 5, 3), "float32") = R.layout_transform(x, index_map=transformation)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
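+
+
+# In the next test, c -> (c // 3, c % 3) remaps the length-20 channel axis onto
+# 7 * 3 = 21 slots, so one slot is implicit padding; pad_value=2 fixes the value
+# stored there instead of leaving the choice to the compiler.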
+
+
+def test_layout_transform_with_padding():
+    transformation = lambda n, c, h, w: (n, c // 3, h, w, c % 3)
+
+    @R.function
+    def foo(x: R.Tensor((10, 20, 2, 2), "float32")):
+        gv: R.Tensor((10, 7, 2, 2, 3), "float32") = R.layout_transform(
+            x, index_map=transformation, pad_value=2
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((10, 20, 2, 2), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation, pad_value=2))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_permute_dims():
+    @R.function
+    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"):
+        gv: R.Tensor((2, 4, 3, 1), "float32") = R.permute_dims(x, axes=[1, -1, 2, -4])
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.permute_dims(x, axes=[1, -1, 2, -4]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_permute_dims_none_arg():
+    @R.function
+    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((4, 3, 2, 1), "float32"):
+        gv: R.Tensor((4, 3, 2, 1), "float32") = R.permute_dims(x)
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.permute_dims(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_reshape():
+    @R.function
+    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"):
+        gv: R.Tensor((8, 3), "float32") = R.reshape(x, (8, 3))
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.reshape(x, shape=(8, 3)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_reshape_infer_dim():
+    @R.function
+    def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 1, 3), "float32"):
+        gv: R.Tensor((8, 1, 3), "float32") = R.reshape(x, (8, -1, 3))
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.reshape(x, shape=(8, -1, 3)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_split_by_indices():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 10, 4), dtype="float32")
+    ) -> R.Tuple(
+        R.Tensor((2, 0, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 4, 4), dtype="float32"),
+        R.Tensor((2, 0, 4), dtype="float32"),
+        R.Tensor((2, 4, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 0, 4), dtype="float32"),
+        R.Tensor((2, 1, 4), dtype="float32"),
+    ):
+        gv: R.Tuple(
+            R.Tensor((2, 0, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 4, 4), dtype="float32"),
+            R.Tensor((2, 0, 4), dtype="float32"),
+            R.Tensor((2, 4, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 0, 4), dtype="float32"),
+            R.Tensor((2, 1, 4), dtype="float32"),
+        ) = R.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
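+
+
+# The zero-extent chunks in the expected tuple above suggest how split treats its
+# indices (reading off the shapes only, not any spec): values are used as given with
+# no negative wrap-around, each is clamped to [0, extent], and an index smaller than
+# its predecessor yields an empty slice rather than an error.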
+
+
+def test_split_by_n_section():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 10, 4), dtype="float32")
+    ) -> R.Tuple(
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+        R.Tensor((2, 2, 4), dtype="float32"),
+    ):
+        gv: R.Tuple(
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+            R.Tensor((2, 2, 4), dtype="float32"),
+        ) = R.split(x, indices_or_sections=5, axis=1)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.split(x, indices_or_sections=5, axis=1))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_squeeze():
+    @R.function
+    def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"):
+        gv: R.Tensor((2, 3, 4), "float32") = R.squeeze(x)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.squeeze(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_squeeze_with_indices():
+    @R.function
+    def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "float32"):
+        gv: R.Tensor((2, 3, 1, 4), "float32") = R.squeeze(x, axis=[3, -5])
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.squeeze(x, axis=[3, -5]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()