From ed4fd502038af83d34c23b83c1760a6ca04e637d Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Sat, 27 Apr 2024 15:15:34 -0500
Subject: [PATCH 1/3] [Relax] Implement relax.op.view

This commit implements `relax.op.view` (`R.view` in TVMScript) to
produce a view into an existing array.  The returned view shares the
same backing allocation as the existing array.

Because `R.view` comes with potential trade-offs, such as an increased
memory footprint, the performance cost of applying a non-zero
`DLTensor::byte_offset`, and potential misalignment for vectorized
operators, this PR does not use `R.view` outside of unit tests.
Applications of `R.view`, whether in specific compute kernels or in
optimization passes, are instead left for follow-up PRs.
---
 python/tvm/relax/expr.py                 |  15 +-
 python/tvm/relax/op/__init__.py          |   1 +
 python/tvm/relax/op/view.py              |  76 +++
 python/tvm/relax/struct_info.py          |   7 +-
 python/tvm/script/ir_builder/relax/ir.py |   2 +
 python/tvm/script/parser/relax/entry.py  |  14 +-
 src/relax/ir/expr.cc                     |  11 +-
 src/relax/op/tensor/view.cc              | 359 +++++++++++
 src/relax/op/tensor/view.h               |  38 ++
 tests/python/relax/test_op_view.py       | 776 +++++++++++++++++++++++
 10 files changed, 1287 insertions(+), 12 deletions(-)
 create mode 100644 python/tvm/relax/op/view.py
 create mode 100644 src/relax/op/tensor/view.cc
 create mode 100644 src/relax/op/tensor/view.h
 create mode 100644 tests/python/relax/test_op_view.py

diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 4dca710e7781..522eb11d6df7 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -1108,21 +1108,26 @@ def inline_functions(
 
 
 @tvm._ffi.register_object("relax.expr.ExternFunc")
-class ExternFunc(BaseFunc):
+class ExternFunc(BaseFunc, ExprWithOp):
     """extern function, which represents a PackedFunc."""
 
     global_symbol: String
     span: Optional[Span]
 
-    def __init__(self, global_symbol: String, span: Optional[Span] = None) -> None:
+    def __init__(
+        self,
+        global_symbol: String,
+        struct_info: Optional[StructInfo] = None,
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
-            _ffi_api.ExternFunc, global_symbol, span  # type: ignore
+            _ffi_api.ExternFunc, global_symbol, struct_info, span  # type: ignore
         )
 
 
-def extern(name: str, span: Optional[Span] = None):
+def extern(name: str, struct_info: Optional[StructInfo] = None, span: Optional[Span] = None):
     """Create extern function."""
-    return ExternFunc(name, span)
+    return ExternFunc(name, struct_info, span)
 
 
 def const(
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 5b585e18b450..760af936da8c 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -135,6 +135,7 @@
     tan,
     tanh,
 )
+from .view import view
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/view.py b/python/tvm/relax/op/view.py
new file mode 100644
index 000000000000..65102baad141
--- /dev/null
+++ b/python/tvm/relax/op/view.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Operations that act on the DLTensor container"""
+from typing import Optional, Sequence, Union
+
+from tvm.ir.expr import PrimExpr
+
+from . import _ffi_api
+from ..expr import Expr, PrimValue, ShapeExpr, DataTypeImm
+
+PrimExprLike = Union[int, PrimExpr]
+
+
+def view(
+    data: Expr,
+    shape: Optional[Union[Sequence[PrimExprLike], Expr]] = None,
+    dtype: Optional[Expr] = None,
+    relative_byte_offset: Optional[Expr] = None,
+) -> Expr:
+    """Produce a view into an existing tensor
+
+    The view shares the same backing allocation as the input tensor,
+    and may apply a different shape, dtype, or relative byte offset.
+
+    Parameters
+    ----------
+    data : relax.Expr
+
+        The input data to the operator.
+
+    shape : Optional[Union[Sequence[PrimExprLike], Expr]]
+
+        The target shape.  Should be a `relax.ShapeExpr`, or a
+        collection that can be converted to a `relax.ShapeExpr`.
+
+    dtype : Optional[Expr]
+
+        The target datatype.  Should be a `relax.DataTypeImm`, or a
+        value that can be converted to a `relax.DataTypeImm`.
+
+    relative_byte_offset: Optional[Expr]
+
+        The offset of the output NDArray, relative to the byte offset
+        of `data`.  If `None`, the offset of the view is the same as
+        the offset of `data`.
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor view
+
+    """
+
+    def _normalize(expr, relax_cls):
+        if expr is None or isinstance(expr, Expr):
+            return expr
+        else:
+            return relax_cls(expr)
+
+    shape = _normalize(shape, ShapeExpr)
+    dtype = _normalize(dtype, DataTypeImm)
+    relative_byte_offset = _normalize(relative_byte_offset, PrimValue)
+
+    return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: ignore
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index 34a9d82595d1..de1b1ac3bfc3 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -233,7 +233,7 @@ def __init__(
     def opaque_func(
         *,
         ret: Optional[StructInfo] = None,
-        derive_func: Optional[EnvFunc] = None,
+        derive_func: Optional[Union[str, EnvFunc]] = None,
         purity: bool = False,
         span: Span = None,
     ) -> "FuncStructInfo":
@@ -249,7 +249,7 @@ def opaque_func(
         ret: Optional[StructInfo]
             The struct info of the function return value.
 
-        derive_func: Optional[EnvFunc]
+        derive_func: Optional[Union[str, EnvFunc]]
             The environment function used for derivation
 
         purity: bool
@@ -266,4 +266,7 @@
         ----
         We cannot specify ret and derive_func simultaneously.
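+
+        Examples
+        --------
+        A hypothetical usage sketch, using only names that appear in
+        this patch: an opaque function whose return struct info is
+        computed by a registered environment function.
+
+        .. code-block:: python
+
+            sinfo = FuncStructInfo.opaque_func(
+                derive_func="tvm.relax.struct_info.infer_view_sinfo",
+                purity=True,
+            )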
""" + + if isinstance(derive_func, str): + derive_func = tvm.ir.EnvFunc.get("tvm.relax.struct_info.infer_view_sinfo") return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 6dbf5c5dfdb4..84032ec72c65 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -133,6 +133,7 @@ sum, take, variance, + view, sigmoid, sign, sin, @@ -794,6 +795,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tuple", "unique", "variance", + "view", "vm", "vpi", "vulkan", diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index a3b391637cb4..7ede3d14fda7 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -20,6 +20,7 @@ from typing import Callable as _Callable from typing import Dict, List, Optional, Set, TypeVar, Union +import tvm from tvm.relax import ( Expr, SeqExpr, @@ -277,6 +278,7 @@ class CallableProxy(StructInfoProxy): params: List[StructInfoProxy] ret: StructInfoProxy purity: bool + derive_func: Optional[Union[str, tvm.ir.EnvFunc]] """Function type. @@ -296,6 +298,9 @@ class CallableProxy(StructInfoProxy): purity : bool Whether the callable is pure. + derive_func: Optional[Union[str, tvm.ir.EnvFunc]] + The derivation function for the outputq + """ def __init__( @@ -303,6 +308,7 @@ def __init__( params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None, ret: Optional[StructInfoProxy] = None, purity: Optional[bool] = None, + derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None, ) -> None: if params is None: self.params = params @@ -320,6 +326,7 @@ def __init__( self.ret = ret() if callable(ret) else ret self.purity = purity + self.derive_func = derive_func def get_symbolic_vars(self) -> Set[str]: if self.params is None: @@ -339,7 +346,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncS params = [param.as_struct_info(dict_globals) for param in self.params] if params is None: - return FuncStructInfo.opaque_func(ret=ret, purity=self.purity) + return FuncStructInfo.opaque_func( + ret=ret, derive_func=self.derive_func, purity=self.purity + ) else: return FuncStructInfo(params, ret, purity=self.purity) @@ -348,8 +357,9 @@ def Callable( params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None, ret: Optional[StructInfoProxy] = None, purity: Optional[bool] = None, + derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None, ) -> CallableProxy: - return CallableProxy(params, ret, purity=purity) + return CallableProxy(params, ret, purity=purity, derive_func=derive_func) ############################### R.Tuple ################################ diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index eb467757653b..59b6a0aeb78b 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -650,9 +650,14 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span) data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) { - return ExternFunc(global_symbol, span); -}); +TVM_REGISTER_GLOBAL("relax.ExternFunc") + .set_body_typed([](String global_symbol, Optional struct_info, Span span) { + if (struct_info.defined()) { + return ExternFunc(global_symbol, struct_info.value(), span); + } else { + return ExternFunc(global_symbol, span); + } + }); Expr GetShapeOf(const Expr& expr) { // default case, to 
   // default case, to be normalized.
diff --git a/src/relax/op/tensor/view.cc b/src/relax/op/tensor/view.cc
new file mode 100644
index 000000000000..78a6ae4aa5b1
--- /dev/null
+++ b/src/relax/op/tensor/view.cc
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file view.cc
+ * \brief Manipulation operators.
+ */
+
+#include "view.h"
+
+namespace tvm {
+namespace relax {
+
+/* relax.broadcast_to */
+Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset) {
+  Tuple void_expr(Array<Expr>{});
+
+  static const Op& op = Op::Get("relax.view");
+  return Call(op, {
+                      x,
+                      shape.value_or(void_expr),
+                      dtype.value_or(void_expr),
+                      relative_byte_offset.value_or(void_expr),
+                  });
+}
+
+TVM_REGISTER_GLOBAL("relax.op.view").set_body_typed(view);
+
+StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << call->op << " should receive 4 arguments, "
+                     << "but received " << call->args);
+  }
+  Expr arg_data = call->args[0];
+  Expr arg_shape = call->args[1];
+  Expr arg_dtype = call->args[2];
+  Expr arg_relative_byte_offset = call->args[3];
+
+  TensorStructInfo data_sinfo = [&]() -> TensorStructInfo {
+    StructInfo sinfo = GetStructInfo(arg_data);
+    if (auto opt = sinfo.as<TensorStructInfo>()) {
+      return opt.value();
+    } else {
+      LOG(FATAL) << "TypeError: "
+                 << "Operator " << call->op << " expects first argument to be a tensor, "
+                 << "but received " << arg_data << " with type " << sinfo;
+    }
+  }();
+  auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* {
+    StructInfo sinfo = GetStructInfo(arg_shape);
+    if (HasVoidStructInfo(arg_shape)) {
+      // No shape change is applied.  The input tensor's shape is
+      // kept as-is.
+      return nullptr;
+    } else if (auto ptr = sinfo.as<ShapeStructInfoNode>()) {
+      // The `R.view` operation returns a different shape.
+      return ptr;
+    } else {
+      LOG(FATAL) << "TypeError: "
+                 << "Operator " << call->op << " expects second argument to be a ShapeExpr, "
+                 << "or a void-type (empty relax tuple), "
+                 << "but received " << arg_shape << " with type " << sinfo;
+    }
+  }();
+
+  auto view_dtype = [&]() -> std::optional<DataType> {
+    StructInfo sinfo = GetStructInfo(arg_dtype);
+
+    if (HasVoidStructInfo(arg_dtype)) {
+      // No datatype change is applied.  The input tensor's dtype is
+      // kept as-is.
+      return std::nullopt;
+    }
+
+    Expr arg_value = arg_dtype;
+    while (auto arg_var = arg_value.as<Var>()) {
+      if (auto bound_value = ctx->LookupBinding(arg_var.value())) {
+        arg_value = bound_value.value();
+      } else {
+        break;
+      }
+    }
+
+    // In general, StructInfo inference should only depend on the
+    // StructInfo of the arguments, and not on the arguments
+    // themselves.
+    // However, `relax::DataTypeImm` uses `ObjectStructInfo`, so we
+    // need to inspect the argument itself in this case.
+    if (auto dtype_imm = arg_value.as<DataTypeImmNode>()) {
+      // We know the datatype for the view.
+      return dtype_imm->value;
+    } else if (sinfo.as<ObjectStructInfoNode>()) {
+      // The view changes the datatype, but we don't know what it is
+      // being changed into.
+      return DataType::Void();
+    } else {
+      LOG(FATAL) << "TypeError: "
+                 << "Operator " << call->op
+                 << " expects the dtype argument to be a relax::DataTypeImm, "
+                 << "but received " << arg_dtype << " with type " << sinfo;
+    }
+  }();
+
+  auto view_relative_byte_offset = [&]() -> Optional<PrimExpr> {
+    StructInfo sinfo = GetStructInfo(arg_relative_byte_offset);
+
+    if (HasVoidStructInfo(arg_relative_byte_offset)) {
+      // No byte offset is specified, so no change is applied.
+      return IntImm(DataType::Int(64), 0);
+    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
+      CHECK_EQ(prim_sinfo->dtype, DataType::Int(64))
+          << "TypeError: "
+          << "Operator " << call->op
+          << " expects the relative_byte_offset to be a 64-bit integer, but received "
+          << arg_relative_byte_offset << ", which has type " << sinfo;
+      if (prim_sinfo->value.defined()) {
+        // An offset of known value is applied.  The known value may
+        // be dynamic.
+        return prim_sinfo->value.value();
+      } else {
+        // An offset of unknown value is applied.
+        return NullOpt;
+      }
+    } else {
+      LOG(FATAL) << "TypeError: "
+                 << "Operator " << call->op << " expects the relative_byte_offset argument "
+                 << "to be a Relax PrimValue.  "
+                 << "However, expression " << call << " provides relative_byte_offset of "
+                 << arg_relative_byte_offset << ", which has type " << sinfo;
+    }
+  }();
+
+  Optional<Array<PrimExpr>> input_shape = data_sinfo->GetShape();
+
+  Optional<Array<PrimExpr>> output_shape = NullOpt;
+  int output_ndim = kUnknownNDim;
+  if (view_shape_sinfo && view_shape_sinfo->values.defined()) {
+    output_shape = view_shape_sinfo->values.value();
+  } else if (view_shape_sinfo) {
+    output_ndim = view_shape_sinfo->ndim;
+  } else if (input_shape) {
+    output_shape = input_shape;
+  } else {
+    output_ndim = data_sinfo->ndim;
+  }
+
+  DataType output_dtype = view_dtype.value_or(data_sinfo->dtype);
+
+  // Helper function, returns the number of bytes per vectorized
+  // element.  Cannot use `DataType::bytes`, as it returns the
+  // number of bytes per scalar element.
+  auto get_size_bytes = [](const DataType& dtype) -> Optional<IntImm> {
+    if (dtype.is_void()) {
+      return NullOpt;
+    } else {
+      auto size_bits = dtype.bits() * dtype.lanes();
+      return IntImm(DataType::Int(64), (size_bits + 7) / 8);
+    }
+  };
+
+  // Helper function, returns the number of elements in an array,
+  // given the shape of that array.
+  auto get_num_elements = [&ctx](const Optional<Array<PrimExpr>>& shape) -> Optional<PrimExpr> {
+    if (!shape.defined()) {
+      return NullOpt;
+    }
+
+    PrimExpr num_elements = Integer(1);
+    for (const auto& dim : shape.value()) {
+      num_elements *= dim;
+    }
+    return ctx->GetAnalyzer()->Simplify(num_elements);
+  };
+
+  Optional<PrimExpr> input_nelements = get_num_elements(input_shape);
+  Optional<PrimExpr> output_nelements = get_num_elements(output_shape);
+
+  Optional<IntImm> input_element_size = get_size_bytes(data_sinfo->dtype);
+  Optional<IntImm> output_element_size = get_size_bytes(output_dtype);
+
+  if (input_nelements && output_nelements && input_element_size && output_element_size &&
+      view_relative_byte_offset) {
+    // The shapes and dtype of input and output are known.  We know
+    // the byte_offset that is applied, and can verify that the view
+    // does not overrun the bounds of the original array.
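+    //
+    // For example (illustrative values): viewing a R.Tensor([16],
+    // "float32") input (64 bytes) with shape [4], dtype "float32",
+    // and relative_byte_offset 56 would end at byte 16 + 56 = 72,
+    // past the 64-byte bound of the input, and is rejected.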
+
+    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
+    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();
+    PrimExpr view_end = output_nbytes + view_relative_byte_offset.value();
+
+    if (ctx->GetAnalyzer()->CanProve(output_nbytes + view_relative_byte_offset.value() >
+                                     input_nbytes)) {
+      LOG(FATAL) << "ValueError: "
+                 << "Views into an array must not exceed the bounds of the array being viewed.  "
+                 << "However, expression " << call << " attempted to create view of type "
+                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
+                 << " with relative byte offset " << view_relative_byte_offset
+                 << ", viewing into the array " << arg_data << " of type " << data_sinfo << ".  "
+                 << "The end of the view would occur at byte " << view_end
+                 << ", relative to the start of array " << arg_data << ", but " << arg_data
+                 << " is only " << input_nbytes << " long.";
+    }
+
+  } else if (input_nelements && output_nelements && input_element_size && output_element_size) {
+    // The shapes and dtype of input and output are known.  However,
+    // we don't know if the `byte_offset` is being adjusted.  We can
+    // still validate using the size of the view.  If the view
+    // is larger than the original array, then it would overrun its
+    // bounds regardless of the `relative_byte_offset` being applied.
+
+    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
+    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();
+
+    if (ctx->GetAnalyzer()->CanProve(output_nbytes > input_nbytes)) {
+      LOG(FATAL) << "ValueError: "
+                 << "Views into an array must not exceed the bounds of the array being viewed.  "
+                 << "However, expression " << call << " attempted to create view of type "
+                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
+                 << " from input array of type " << data_sinfo << ".  "
+                 << "This view would increase the size from " << input_nbytes << " bytes to "
+                 << output_nbytes << " bytes.";
+    }
+
+  } else if (input_element_size && output_element_size && !view_shape_sinfo) {
+    // The output view has a known dtype, which is different from the
+    // known dtype of the input array.  Because the view's shape is
+    // the same as the original array, when counted in number of
+    // elements, an increase to the per-element size would cause the
+    // view to be larger than the original array.
+
+    CHECK_GE(input_element_size.value()->value, output_element_size.value()->value)
+        << "ValueError: "
+        << "Operator " << call->op
+        << " may not produce a view that exceeds the bounds of the original array.  "
+        << "In expression " << call << " the data type is changed from " << data_sinfo->dtype
+        << " to " << view_dtype.value() << ", increasing the size per element from "
+        << input_element_size << " bytes to " << output_element_size << " bytes.  "
+        << "Consider providing a new shape for the R.view.";
+  } else if (input_nelements && output_nelements && !view_dtype) {
+    // The shape is being updated, while keeping the datatype the
+    // same.  Even though we don't know the size of each element, we
+    // know it must be the same for the input and output arrays.  An
+    // increase to the number of elements would cause the view to be
+    // larger than the original array, regardless of the size of each
+    // individual element.
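+    //
+    // For example, viewing a 16-element array as shape [4, 5] (20
+    // elements) is rejected here, regardless of the element type.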
+
+    if (ctx->GetAnalyzer()->CanProve(output_nelements.value() > input_nelements.value())) {
+      LOG(FATAL) << "ValueError: "
+                 << "Views into an array must not exceed the bounds of the array being viewed.  "
+                 << "However, expression " << call << " attempted to view array " << arg_data
+                 << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape "
+                 << output_shape << " with " << output_nelements << " elements.";
+    }
+  } else if (view_relative_byte_offset && !view_shape_sinfo && !view_dtype) {
+    // The byte_offset is being updated, but neither the shape nor the
+    // dtype is changing.  Any non-zero offset will cause the view to
+    // overrun the bounds of the original array.
+    if (ctx->GetAnalyzer()->CanProve(view_relative_byte_offset.value() > 0)) {
+      LOG(FATAL) << "ValueError: "
+                 << "Views into an array must not exceed the bounds of the array being viewed.  "
+                 << "However, expression " << call << " attempted to offset the view by "
+                 << view_relative_byte_offset << " bytes, "
+                 << "without reducing either the number of elements in the view "
+                 << "or the size of each element.";
+    }
+  }
+
+  if (output_shape.defined()) {
+    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, data_sinfo->vdevice);
+  } else {
+    return TensorStructInfo(output_dtype, output_ndim, data_sinfo->vdevice);
+  }
+}
+
+TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);
+
+Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
+  Expr data = call->args[0];
+  Expr shape = call->args[1];
+  Expr dtype = call->args[2];
+  Expr relative_byte_offset = call->args[3];
+
+  if (HasVoidStructInfo(shape) && HasVoidStructInfo(dtype) &&
+      HasVoidStructInfo(relative_byte_offset)) {
+    // Special-case, no change is required by the view.
+    return data;
+  }
+
+  // Prior to legalization, it is useful to use a void-type argument
+  // to specify "no change".  This allows for better shape inference
+  // when a pass updates the input `data` tensor.  However, when we
+  // legalize the `R.view`, we must provide explicit parameters.
+
+  if (HasVoidStructInfo(shape)) {
+    auto data_shape = data->struct_info_.as<TensorStructInfo>().value()->GetShape();
+    CHECK(data_shape.defined())
+        << "Legalization of " << call->op
+        << " requires that either the output shape be explicitly specified, "
+        << "or the input shape is known.  "
+        << "However, in expression " << call << ", no output shape is specified, "
+        << "and the input " << data << " of type " << data->struct_info_ << " has unknown shape.";
+    shape = ShapeExpr(data_shape.value());
+  }
+
+  if (HasVoidStructInfo(dtype)) {
+    auto data_dtype = data->struct_info_.as<TensorStructInfo>().value()->dtype;
+    CHECK(!data_dtype.is_void())
+        << "Legalization of " << call->op
+        << " requires that either the output dtype be explicitly specified, "
+        << "or the input dtype is known.  "
" + << "However, in expression " << call << ", no output dtype is specified, " + << "and the input " << data << " of type " << data->struct_info_ << " has unknown dtype."; + dtype = relax::DataTypeImm(data_dtype); + } + + if (HasVoidStructInfo(relative_byte_offset)) { + relative_byte_offset = relax::PrimValue::Int64(0); + } + + StructInfoDeriveFunc infer_sinfo_env_func; + infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + + return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); +} + +TVM_REGISTER_OP("relax.view") + .set_num_inputs(4) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The view's shape.") + .add_argument("dtype", "DataType", "The view's data type.") + .add_argument("relative_byte_offset", "Prim(\"int64\")", + "The view's byte offset, relative to the input tensor's byte offset.") + .set_attr("RequiresArgumentShapes", Bool(false)) + .set_attr("FInferStructInfo", InferStructInfoView) + .set_attr("FLegalize", LegalizeView) + .set_attr("FPurity", Bool(true)); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/view.h b/src/relax/op/tensor/view.h new file mode 100644 index 000000000000..f2c6304e809e --- /dev/null +++ b/src/relax/op/tensor/view.h @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file view.h + * \brief The functions to make Relax tensor view calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_VIEW_H_ +#define TVM_RELAX_OP_TENSOR_VIEW_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! \brief View a tensor with different properties. */ +Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_VIEW_H_ diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py new file mode 100644 index 000000000000..8f14cad7b354 --- /dev/null +++ b/tests/python/relax/test_op_view.py @@ -0,0 +1,776 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + +import numpy as np +import pytest + + +def test_infer_shape_of_1d_static_view(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor) -> R.Tensor([4096]): + B: R.Tensor([4096]) = R.view(A, R.shape([4096])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor): + B = R.view(A, R.shape([4096])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_shape_of_2d_static_view(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor) -> R.Tensor([64, 64]): + B: R.Tensor([64, 64]) = R.view(A, R.shape([64, 64])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor): + B = R.view(A, R.shape([64, 64])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_error_if_shape_argument_is_not_shape(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor([16])): + B = R.view(A, R.prim_value(42)) + return B + + +def test_infer_shape_of_1d_static_view_smaller_than_1d_source(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([16]): + B: R.Tensor([16]) = R.view(A, R.shape([16])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([4096])): + B = R.view(A, R.shape([16])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_shape_of_2d_static_view_smaller_than_1d_source(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([4, 4]): + B: R.Tensor([4, 4]) = R.view(A, R.shape([4, 4])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([4096])): + B = R.view(A, R.shape([4, 4])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_shape_of_2d_static_view_same_size_as_2d_source(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([64, 64])) -> R.Tensor([16, 256]): + B: R.Tensor([16, 256]) = R.view(A, R.shape([16, 256])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([64, 64])): + B = R.view(A, R.shape([16, 256])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_error_if_1d_static_view_larger_than_1d_source(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor([16])): + B = R.view(A, R.shape([17])) + return B + + +def test_error_if_static_2d_view_larger_than_source(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor([16])): + B = R.view(A, R.shape([4, 5])) + return B + + +def test_infer_shape_of_1d_dynamic_view(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): + N = T.int64() + B: R.Tensor([N // 2]) = R.view(A, R.shape([N // 2])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N // 2])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def 
test_infer_shape_of_2d_dynamic_view_of_1d_source(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 8", 8]): + N = T.int64() + B: R.Tensor([N // 8, 8]) = R.view(A, R.shape([N // 8, 8])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N // 8, 8])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_shape_of_2d_dynamic_view(): + @R.function(private=True) + def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): + N = T.int64() + B: R.Tensor([N // 2]) = R.view(A, R.shape([N // 2])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N // 2])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_error_if_1d_dynamic_view_larger_than_1d_source(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N + 1])) + return B + + +@pytest.mark.xfail(reason="See https://github.com/apache/tvm/pull/16877") +def test_error_if_1d_dynamic_view_provably_larger_than_1d_source(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N + T.if_then_else(N < 0, -1, 1)])) + return B + + +def test_error_if_2d_dynamic_view_provably_larger_than_1d_source(): + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([N // 4 + 1, 4])) + return B + + +def test_validity_of_dynamic_view_may_depend_on_runtime_value(): + """Validity checks may be delayed until runtime + + The runtime implementation of `R.view` checks the validity of any + dynamic shape. A compile-time error should only be issued the + runtime check would fail for *all* dynamic shapes. + + In this example, the output of `R.view` contains `N` elements when + `N` is evenly divisible by 4, and `N+4` elements otherwise. The + runtime check would pass whenever the argument's size is divisible + by 4. Even though the runtime check would fail when `N` isn't + divisible by 4, no compile-time error should be emitted. + + """ + + @R.function + def func(A: R.Tensor(["N"])): + N = T.int64() + B = R.view(A, R.shape([(N + 3) // 4, 4])) + return B + + +def test_infer_dtype_of_float32_view(): + """R.view can reinterpret the contents as another type + + For example, if the same backing allocation is used for multiple + arrays with distinct datatypes. 
+ + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): + B: R.Tensor("float32") = R.view(A, dtype=R.dtype("float32")) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor): + B = R.view(A, dtype=R.dtype("float32")) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_view_without_explicit_dtype_keeps_input_dtype(): + """If R.view only specifies the shape, the dtype is unchanged""" + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([16], "float32")) -> R.Tensor([4, 4], "float32"): + B: R.Tensor([4, 4], "float32") = R.view(A, R.shape([4, 4])) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([16], "float32")): + B = R.view(A, R.shape([4, 4])) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_dtype_of_float32_view_from_relax_var(): + """R.view can reinterpret the contents as another type + + Any relax object can be stored in a relax variable. Even if the + `R.dtype` argument is stored in a variable, struct inference may + be applied. + + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): + dtype = R.dtype("float32") + B: R.Tensor("float32") = R.view(A, dtype=dtype) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor): + dtype = R.dtype("float32") + B = R.view(A, dtype=dtype) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_infer_dtype_of_view_with_unknown_dtype(): + """DType may be provided as argument + + Because we do not know the value provided in `dtype`, the element + type of the array is unknown. + + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor("float32"), dtype: R.Object) -> R.Tensor: + B: R.Tensor = R.view(A, dtype=dtype) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor("float32"), dtype: R.Object): + B = R.view(A, dtype=dtype) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_view_dtype_may_be_smaller_than_input_dtype(): + """Viewing with a smaller dtype does not exceed original bounds + + This is not typically desired behavior, as the view would span + fewer bytes than the original array. However, this is legal, and + may occur as the result of optimization passes. + + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor("uint32")) -> R.Tensor("float8"): + B: R.Tensor("float8") = R.view(A, dtype=R.dtype("float8")) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor("uint32")): + B = R.view(A, dtype=R.dtype("float8")) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_error_if_view_dtype_is_larger_than_input_dtype(): + """A view may not exceed the bounds of the viewed array""" + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor([16], "uint8")): + B = R.view(A, dtype=R.dtype("float16")) + return B + + +def test_increase_dtype_size_while_decreasing_number_of_elements(): + """R.view may update both dtype and shape simultaneously + + Like `test_error_if_dtype_results_in_larger_view`, but the view + contains fewer elements than the backing array. This results in a + view that is the same size as the backing array, and would not + exceed the bounds of the original array. 
+ + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([16], "uint8")) -> R.Tensor([8], "float16"): + B: R.Tensor([8], "float16") = R.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([16], "uint8")): + B = R.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_decrease_dtype_size_while_increasing_number_of_elements(): + """R.view may update both dtype and shape simultaneously""" + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor([8], "float16")) -> R.Tensor([16], "uint8"): + B: R.Tensor([16], "uint8") = R.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor([8], "float16")): + B = R.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_error_if_number_of_bytes_of_view_is_larger_than_original(): + """R.view may update both dtype and shape simultaneously + + In this test case, the source array is 16 bytes (8 elements * 2 + bytes/element), but the view is 32 bytes (32 elements * 1 + byte/element). + + """ + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor([8], "float16")): + B = R.view(A, shape=R.shape([32]), dtype=R.dtype("uint8")) + return B + + +def test_error_for_non_zero_relative_byte_offset(): + """R.view must not exceed bounds of the original array + + Providing a non-zero `relative_byte_offset`, without updating + either the dtype or the shape of the array, would allow the view + to overrun the end of the original array. + + """ + + with pytest.raises(tvm.TVMError): + + @R.function + def func(A: R.Tensor): + B = R.view(A, relative_byte_offset=16) + return B + + +def test_applying_relative_byte_offset_of_zero_is_legal(): + """Using relative_byte_offset=0 is no-op + + Providing a `relative_byte_offset` of zero, without updating + either the dtype or the shape of the array, is legal, though it is + a no-op. + + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor) -> R.Tensor: + B: R.Tensor = R.view(A, relative_byte_offset=R.prim_value(0)) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor): + B = R.view(A, relative_byte_offset=R.prim_value(0)) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_applying_unknown_relative_byte_offset_is_legal(): + """Using an unknown relative_byte_offset is legal + + Since providing a `relative_byte_offset` of zero, without updating + either the dtype or the shape of the array, is legal, we may not + emit a compile-time error for an unknown `relative_byte_offset` in + this case. 
+ + """ + + @R.function(private=True) + def explicit_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")) -> R.Tensor: + B: R.Tensor = R.view(A, relative_byte_offset=relative_byte_offset) + return B + + @R.function(private=True) + def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")): + B = R.view(A, relative_byte_offset=relative_byte_offset) + return B + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + +def test_legalize_without_any_changes_is_no_op(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = A + return B + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_legalize_shape_change(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A, shape=R.shape([64, 64])) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([64, 64]), + R.dtype("float32"), + R.prim_value(0), + ) + return B + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_legalize_view_shape_from_unknown(): + """R.view does not require the input tensor to have a known shape""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor(dtype="float32")): + B = R.view(A, shape=R.shape([64, 64])) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor(dtype="float32")): + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([64, 64]), + R.dtype("float32"), + R.prim_value(0), + ) + return B + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_legalize_dtype_change(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A, dtype=R.dtype("int32")) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([4096]), + R.dtype("int32"), + R.prim_value(0), + ) + return B + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_legalize_byte_offset(): + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A, relative_byte_offset=R.prim_value(0)) + return B + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([4096]), + R.dtype("float32"), + R.prim_value(0), + ) + return B + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_legalize_view_with_multiple_updated_fields(): + """R.view may update more than one field in the view + + In this test case, a 4-kilobyte buffer is provided. 
The first + 2-kilobytes of the buffer are used as a 1-d array of 512 int32. + The last 2-kilobytes of the buffer are used as a 2-d array of + [16,64] float16 values. + + """ + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([4096], "uint8")): + B = R.view( + A, + shape=R.shape([512]), + dtype=R.dtype("int32"), + ) + C = R.view( + A, + shape=R.shape([16, 64]), + dtype=R.dtype("float16"), + relative_byte_offset=R.prim_value(2048), + ) + return (B, C) + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([4096], "uint8")): + B = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([512]), + R.dtype("int32"), + R.prim_value(0), + ) + C = R.ExternFunc( + "runtime.TVMArrayCreateView", + R.Callable( + derive_func="tvm.relax.struct_info.infer_view_sinfo", + purity=True, + ), + )( + A, + R.shape([16, 64]), + R.dtype("float16"), + R.prim_value(2048), + ) + return (B, C) + + After = tvm.relax.transform.LegalizeOps()(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_no_op_view(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A) + return B + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_view_with_new_shape(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A, shape=R.shape([64, 64])) + return B + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input.reshape(64, 64) + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_view_with_new_byte_offset(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view( + A, + shape=R.shape([16, 64]), + relative_byte_offset=32 * 64 * 4, + ) + return B + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input.reshape(64, 64)[32:48, :] + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_execute_view_with_new_dtype(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "float32")): + B = R.view(A, dtype="uint32") + return B + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.random([4096]).astype("float32") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = np_input.view("uint32") + + tvm.testing.assert_allclose(tvm_output.numpy(), np_expected) + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def 
test_execute_view_with_multiple_updated_fields(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([4096], "uint8")): + B = R.view( + A, + shape=R.shape([512]), + dtype=R.dtype("int32"), + ) + C = R.view( + A, + shape=R.shape([16, 64]), + dtype=R.dtype("float16"), + relative_byte_offset=R.prim_value(2048), + ) + return (B, C) + + built = tvm.relax.build(Module, target=target) + vm = tvm.relax.VirtualMachine(built, device=dev) + + np_input = np.random.randint(0, 255, size=[4096]).astype("uint8") + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + np_expected = [ + np_input[:2048].view("int32"), + np_input[2048:].view("float16").reshape(16, 64), + ] + + tvm.testing.assert_allclose(tvm_output[0].numpy(), np_expected[0]) + tvm.testing.assert_allclose(tvm_output[1].numpy(), np_expected[1]) + + +if __name__ == "__main__": + tvm.testing.main() From fb7e615a975f4f8c4e192ac7a58473e49933d543 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 30 Apr 2024 15:26:58 -0500 Subject: [PATCH 2/3] Move view operation to be in the "memory" group - Rename `R.view` to `R.memory.view` - Rename `relax.op.view` to `relax.op.memory.view` --- python/tvm/relax/op/__init__.py | 1 - python/tvm/relax/op/memory/__init__.py | 1 + python/tvm/relax/op/{ => memory}/view.py | 5 +- python/tvm/script/ir_builder/relax/ir.py | 2 - python/tvm/script/parser/relax/parser.py | 2 +- src/relax/op/{tensor => memory}/view.cc | 10 +- src/relax/op/{tensor => memory}/view.h | 6 +- tests/python/relax/test_op_view.py | 136 +++++++++++------------ 8 files changed, 81 insertions(+), 82 deletions(-) rename python/tvm/relax/op/{ => memory}/view.py (96%) rename src/relax/op/{tensor => memory}/view.cc (98%) rename src/relax/op/{tensor => memory}/view.h (91%) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 760af936da8c..5b585e18b450 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -135,7 +135,6 @@ tan, tanh, ) -from .view import view def _register_op_make(): diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py index 45819f4cb395..422c5d2e1f53 100644 --- a/python/tvm/relax/op/memory/__init__.py +++ b/python/tvm/relax/op/memory/__init__.py @@ -17,3 +17,4 @@ """Relax memory primitives.""" from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor +from .view import view diff --git a/python/tvm/relax/op/view.py b/python/tvm/relax/op/memory/view.py similarity index 96% rename from python/tvm/relax/op/view.py rename to python/tvm/relax/op/memory/view.py index 65102baad141..cf7e14f53f49 100644 --- a/python/tvm/relax/op/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -18,10 +18,11 @@ """Operations that act on the DLTensor container """ from typing import Optional, Sequence, Union -from tvm.ir.expr import PrimExpr +from tvm.tir import PrimExpr +from tvm.relax import Expr, ShapeExpr, DataTypeImm, PrimValue from . 
import _ffi_api -from ..expr import Expr, PrimValue, ShapeExpr, DataTypeImm + PrimExprLike = Union[int, PrimExpr] diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 84032ec72c65..6dbf5c5dfdb4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -133,7 +133,6 @@ sum, take, variance, - view, sigmoid, sign, sin, @@ -795,7 +794,6 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "tuple", "unique", "variance", - "view", "vm", "vpi", "vulkan", diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 9d73749b0aa4..400c023aa7e8 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -108,7 +108,7 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St struct_info = self.eval_expr(node) return _normalize_struct_info(struct_info, var_table) except Exception as err: - self.report_error(node, str(err)) + self.report_error(node, err) raise err diff --git a/src/relax/op/tensor/view.cc b/src/relax/op/memory/view.cc similarity index 98% rename from src/relax/op/tensor/view.cc rename to src/relax/op/memory/view.cc index 78a6ae4aa5b1..e7634c7edfce 100644 --- a/src/relax/op/tensor/view.cc +++ b/src/relax/op/memory/view.cc @@ -19,7 +19,7 @@ /*! * \file view.cc - * \brief Manipulation operators. + * \brief Operator to view an existing tensor. */ #include "view.h" @@ -27,11 +27,11 @@ namespace tvm { namespace relax { -/* relax.broadcast_to */ +/* relax.op.memory.view */ Expr view(Expr x, Optional shape, Optional dtype, Optional relative_byte_offset) { Tuple void_expr(Array{}); - static const Op& op = Op::Get("relax.view"); + static const Op& op = Op::Get("relax.memory.view"); return Call(op, { x, shape.value_or(void_expr), @@ -40,7 +40,7 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel }); } -TVM_REGISTER_GLOBAL("relax.op.view").set_body_typed(view); +TVM_REGISTER_GLOBAL("relax.op.memory.view").set_body_typed(view); StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 4) { @@ -343,7 +343,7 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { return Call(runtime_view_func, {data, shape, dtype, relative_byte_offset}); } -TVM_REGISTER_OP("relax.view") +TVM_REGISTER_OP("relax.memory.view") .set_num_inputs(4) .add_argument("x", "Tensor", "The input tensor.") .add_argument("shape", "Shape", "The view's shape.") diff --git a/src/relax/op/tensor/view.h b/src/relax/op/memory/view.h similarity index 91% rename from src/relax/op/tensor/view.h rename to src/relax/op/memory/view.h index f2c6304e809e..bc8002fa5b69 100644 --- a/src/relax/op/tensor/view.h +++ b/src/relax/op/memory/view.h @@ -21,8 +21,8 @@ * \file view.h * \brief The functions to make Relax tensor view calls. 
*/ -#ifndef TVM_RELAX_OP_TENSOR_VIEW_H_ -#define TVM_RELAX_OP_TENSOR_VIEW_H_ +#ifndef TVM_RELAX_OP_MEMORY_VIEW_H_ +#define TVM_RELAX_OP_MEMORY_VIEW_H_ #include "../op_common.h" @@ -35,4 +35,4 @@ Expr view(Expr x, Optional shape, Optional dtype, Optional rel } // namespace relax } // namespace tvm -#endif // TVM_RELAX_OP_TENSOR_VIEW_H_ +#endif // TVM_RELAX_OP_MEMORY_VIEW_H_ diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 8f14cad7b354..2433821c2abd 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -26,12 +26,12 @@ def test_infer_shape_of_1d_static_view(): @R.function(private=True) def explicit_sinfo(A: R.Tensor) -> R.Tensor([4096]): - B: R.Tensor([4096]) = R.view(A, R.shape([4096])) + B: R.Tensor([4096]) = R.memory.view(A, R.shape([4096])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor): - B = R.view(A, R.shape([4096])) + B = R.memory.view(A, R.shape([4096])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -40,12 +40,12 @@ def inferred_sinfo(A: R.Tensor): def test_infer_shape_of_2d_static_view(): @R.function(private=True) def explicit_sinfo(A: R.Tensor) -> R.Tensor([64, 64]): - B: R.Tensor([64, 64]) = R.view(A, R.shape([64, 64])) + B: R.Tensor([64, 64]) = R.memory.view(A, R.shape([64, 64])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor): - B = R.view(A, R.shape([64, 64])) + B = R.memory.view(A, R.shape([64, 64])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -56,19 +56,19 @@ def test_error_if_shape_argument_is_not_shape(): @R.function def func(A: R.Tensor([16])): - B = R.view(A, R.prim_value(42)) + B = R.memory.view(A, R.prim_value(42)) return B def test_infer_shape_of_1d_static_view_smaller_than_1d_source(): @R.function(private=True) def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([16]): - B: R.Tensor([16]) = R.view(A, R.shape([16])) + B: R.Tensor([16]) = R.memory.view(A, R.shape([16])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([4096])): - B = R.view(A, R.shape([16])) + B = R.memory.view(A, R.shape([16])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -77,12 +77,12 @@ def inferred_sinfo(A: R.Tensor([4096])): def test_infer_shape_of_2d_static_view_smaller_than_1d_source(): @R.function(private=True) def explicit_sinfo(A: R.Tensor([4096])) -> R.Tensor([4, 4]): - B: R.Tensor([4, 4]) = R.view(A, R.shape([4, 4])) + B: R.Tensor([4, 4]) = R.memory.view(A, R.shape([4, 4])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([4096])): - B = R.view(A, R.shape([4, 4])) + B = R.memory.view(A, R.shape([4, 4])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -91,12 +91,12 @@ def inferred_sinfo(A: R.Tensor([4096])): def test_infer_shape_of_2d_static_view_same_size_as_2d_source(): @R.function(private=True) def explicit_sinfo(A: R.Tensor([64, 64])) -> R.Tensor([16, 256]): - B: R.Tensor([16, 256]) = R.view(A, R.shape([16, 256])) + B: R.Tensor([16, 256]) = R.memory.view(A, R.shape([16, 256])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([64, 64])): - B = R.view(A, R.shape([16, 256])) + B = R.memory.view(A, R.shape([16, 256])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -107,7 +107,7 @@ def test_error_if_1d_static_view_larger_than_1d_source(): @R.function def func(A: R.Tensor([16])): - B = R.view(A, R.shape([17])) + B = R.memory.view(A, R.shape([17])) return B @@ -116,7 +116,7 @@ def 
test_error_if_static_2d_view_larger_than_source(): @R.function def func(A: R.Tensor([16])): - B = R.view(A, R.shape([4, 5])) + B = R.memory.view(A, R.shape([4, 5])) return B @@ -124,13 +124,13 @@ def test_infer_shape_of_1d_dynamic_view(): @R.function(private=True) def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): N = T.int64() - B: R.Tensor([N // 2]) = R.view(A, R.shape([N // 2])) + B: R.Tensor([N // 2]) = R.memory.view(A, R.shape([N // 2])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N // 2])) + B = R.memory.view(A, R.shape([N // 2])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -140,13 +140,13 @@ def test_infer_shape_of_2d_dynamic_view_of_1d_source(): @R.function(private=True) def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 8", 8]): N = T.int64() - B: R.Tensor([N // 8, 8]) = R.view(A, R.shape([N // 8, 8])) + B: R.Tensor([N // 8, 8]) = R.memory.view(A, R.shape([N // 8, 8])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N // 8, 8])) + B = R.memory.view(A, R.shape([N // 8, 8])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -156,13 +156,13 @@ def test_infer_shape_of_2d_dynamic_view(): @R.function(private=True) def explicit_sinfo(A: R.Tensor(["N"])) -> R.Tensor(["N // 2"]): N = T.int64() - B: R.Tensor([N // 2]) = R.view(A, R.shape([N // 2])) + B: R.Tensor([N // 2]) = R.memory.view(A, R.shape([N // 2])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N // 2])) + B = R.memory.view(A, R.shape([N // 2])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -174,7 +174,7 @@ def test_error_if_1d_dynamic_view_larger_than_1d_source(): @R.function def func(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N + 1])) + B = R.memory.view(A, R.shape([N + 1])) return B @@ -185,7 +185,7 @@ def test_error_if_1d_dynamic_view_provably_larger_than_1d_source(): @R.function def func(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N + T.if_then_else(N < 0, -1, 1)])) + B = R.memory.view(A, R.shape([N + T.if_then_else(N < 0, -1, 1)])) return B @@ -195,18 +195,18 @@ def test_error_if_2d_dynamic_view_provably_larger_than_1d_source(): @R.function def func(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([N // 4 + 1, 4])) + B = R.memory.view(A, R.shape([N // 4 + 1, 4])) return B def test_validity_of_dynamic_view_may_depend_on_runtime_value(): """Validity checks may be delayed until runtime - The runtime implementation of `R.view` checks the validity of any + The runtime implementation of `R.memory.view` checks the validity of any dynamic shape. A compile-time error should only be issued the runtime check would fail for *all* dynamic shapes. - In this example, the output of `R.view` contains `N` elements when + In this example, the output of `R.memory.view` contains `N` elements when `N` is evenly divisible by 4, and `N+4` elements otherwise. The runtime check would pass whenever the argument's size is divisible by 4. 
Even though the runtime check would fail when `N` isn't @@ -217,12 +217,12 @@ def test_validity_of_dynamic_view_may_depend_on_runtime_value(): @R.function def func(A: R.Tensor(["N"])): N = T.int64() - B = R.view(A, R.shape([(N + 3) // 4, 4])) + B = R.memory.view(A, R.shape([(N + 3) // 4, 4])) return B def test_infer_dtype_of_float32_view(): - """R.view can reinterpret the contents as another type + """R.memory.view can reinterpret the contents as another type For example, if the same backing allocation is used for multiple arrays with distinct datatypes. @@ -231,35 +231,35 @@ def test_infer_dtype_of_float32_view(): @R.function(private=True) def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): - B: R.Tensor("float32") = R.view(A, dtype=R.dtype("float32")) + B: R.Tensor("float32") = R.memory.view(A, dtype=R.dtype("float32")) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor): - B = R.view(A, dtype=R.dtype("float32")) + B = R.memory.view(A, dtype=R.dtype("float32")) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) def test_view_without_explicit_dtype_keeps_input_dtype(): - """If R.view only specifies the shape, the dtype is unchanged""" + """If R.memory.view only specifies the shape, the dtype is unchanged""" @R.function(private=True) def explicit_sinfo(A: R.Tensor([16], "float32")) -> R.Tensor([4, 4], "float32"): - B: R.Tensor([4, 4], "float32") = R.view(A, R.shape([4, 4])) + B: R.Tensor([4, 4], "float32") = R.memory.view(A, R.shape([4, 4])) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([16], "float32")): - B = R.view(A, R.shape([4, 4])) + B = R.memory.view(A, R.shape([4, 4])) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) def test_infer_dtype_of_float32_view_from_relax_var(): - """R.view can reinterpret the contents as another type + """R.memory.view can reinterpret the contents as another type Any relax object can be stored in a relax variable. 
Even if the `R.dtype` argument is stored in a variable, struct inference may @@ -270,13 +270,13 @@ def test_infer_dtype_of_float32_view_from_relax_var(): @R.function(private=True) def explicit_sinfo(A: R.Tensor) -> R.Tensor("float32"): dtype = R.dtype("float32") - B: R.Tensor("float32") = R.view(A, dtype=dtype) + B: R.Tensor("float32") = R.memory.view(A, dtype=dtype) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor): dtype = R.dtype("float32") - B = R.view(A, dtype=dtype) + B = R.memory.view(A, dtype=dtype) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -292,12 +292,12 @@ def test_infer_dtype_of_view_with_unknown_dtype(): @R.function(private=True) def explicit_sinfo(A: R.Tensor("float32"), dtype: R.Object) -> R.Tensor: - B: R.Tensor = R.view(A, dtype=dtype) + B: R.Tensor = R.memory.view(A, dtype=dtype) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor("float32"), dtype: R.Object): - B = R.view(A, dtype=dtype) + B = R.memory.view(A, dtype=dtype) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -314,12 +314,12 @@ def test_view_dtype_may_be_smaller_than_input_dtype(): @R.function(private=True) def explicit_sinfo(A: R.Tensor("uint32")) -> R.Tensor("float8"): - B: R.Tensor("float8") = R.view(A, dtype=R.dtype("float8")) + B: R.Tensor("float8") = R.memory.view(A, dtype=R.dtype("float8")) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor("uint32")): - B = R.view(A, dtype=R.dtype("float8")) + B = R.memory.view(A, dtype=R.dtype("float8")) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -331,12 +331,12 @@ def test_error_if_view_dtype_is_larger_than_input_dtype(): @R.function def func(A: R.Tensor([16], "uint8")): - B = R.view(A, dtype=R.dtype("float16")) + B = R.memory.view(A, dtype=R.dtype("float16")) return B def test_increase_dtype_size_while_decreasing_number_of_elements(): - """R.view may update both dtype and shape simultaneously + """R.memory.view may update both dtype and shape simultaneously Like `test_error_if_dtype_results_in_larger_view`, but the view contains fewer elements than the backing array. 
This results in a @@ -347,35 +347,35 @@ def test_increase_dtype_size_while_decreasing_number_of_elements(): @R.function(private=True) def explicit_sinfo(A: R.Tensor([16], "uint8")) -> R.Tensor([8], "float16"): - B: R.Tensor([8], "float16") = R.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) + B: R.Tensor([8], "float16") = R.memory.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([16], "uint8")): - B = R.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) + B = R.memory.view(A, shape=R.shape([8]), dtype=R.dtype("float16")) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) def test_decrease_dtype_size_while_increasing_number_of_elements(): - """R.view may update both dtype and shape simultaneously""" + """R.memory.view may update both dtype and shape simultaneously""" @R.function(private=True) def explicit_sinfo(A: R.Tensor([8], "float16")) -> R.Tensor([16], "uint8"): - B: R.Tensor([16], "uint8") = R.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) + B: R.Tensor([16], "uint8") = R.memory.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor([8], "float16")): - B = R.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) + B = R.memory.view(A, shape=R.shape([16]), dtype=R.dtype("uint8")) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) def test_error_if_number_of_bytes_of_view_is_larger_than_original(): - """R.view may update both dtype and shape simultaneously + """R.memory.view may update both dtype and shape simultaneously In this test case, the source array is 16 bytes (8 elements * 2 bytes/element), but the view is 32 bytes (32 elements * 1 @@ -386,12 +386,12 @@ def test_error_if_number_of_bytes_of_view_is_larger_than_original(): @R.function def func(A: R.Tensor([8], "float16")): - B = R.view(A, shape=R.shape([32]), dtype=R.dtype("uint8")) + B = R.memory.view(A, shape=R.shape([32]), dtype=R.dtype("uint8")) return B def test_error_for_non_zero_relative_byte_offset(): - """R.view must not exceed bounds of the original array + """R.memory.view must not exceed bounds of the original array Providing a non-zero `relative_byte_offset`, without updating either the dtype or the shape of the array, would allow the view @@ -403,7 +403,7 @@ def test_error_for_non_zero_relative_byte_offset(): @R.function def func(A: R.Tensor): - B = R.view(A, relative_byte_offset=16) + B = R.memory.view(A, relative_byte_offset=16) return B @@ -418,12 +418,12 @@ def test_applying_relative_byte_offset_of_zero_is_legal(): @R.function(private=True) def explicit_sinfo(A: R.Tensor) -> R.Tensor: - B: R.Tensor = R.view(A, relative_byte_offset=R.prim_value(0)) + B: R.Tensor = R.memory.view(A, relative_byte_offset=R.prim_value(0)) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor): - B = R.view(A, relative_byte_offset=R.prim_value(0)) + B = R.memory.view(A, relative_byte_offset=R.prim_value(0)) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -441,12 +441,12 @@ def test_applying_unknown_relative_byte_offset_is_legal(): @R.function(private=True) def explicit_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")) -> R.Tensor: - B: R.Tensor = R.view(A, relative_byte_offset=relative_byte_offset) + B: R.Tensor = R.memory.view(A, relative_byte_offset=relative_byte_offset) return B @R.function(private=True) def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")): - B = R.view(A, 
relative_byte_offset=relative_byte_offset) + B = R.memory.view(A, relative_byte_offset=relative_byte_offset) return B tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) @@ -457,7 +457,7 @@ def test_legalize_without_any_changes_is_no_op(): class Before: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A) + B = R.memory.view(A) return B @I.ir_module @@ -476,7 +476,7 @@ def test_legalize_shape_change(): class Before: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A, shape=R.shape([64, 64])) + B = R.memory.view(A, shape=R.shape([64, 64])) return B @I.ir_module @@ -502,13 +502,13 @@ def main(A: R.Tensor([4096], "float32")): def test_legalize_view_shape_from_unknown(): - """R.view does not require the input tensor to have a known shape""" + """R.memory.view does not require the input tensor to have a known shape""" @I.ir_module class Before: @R.function def main(A: R.Tensor(dtype="float32")): - B = R.view(A, shape=R.shape([64, 64])) + B = R.memory.view(A, shape=R.shape([64, 64])) return B @I.ir_module @@ -538,7 +538,7 @@ def test_legalize_dtype_change(): class Before: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A, dtype=R.dtype("int32")) + B = R.memory.view(A, dtype=R.dtype("int32")) return B @I.ir_module @@ -568,7 +568,7 @@ def test_legalize_byte_offset(): class Before: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A, relative_byte_offset=R.prim_value(0)) + B = R.memory.view(A, relative_byte_offset=R.prim_value(0)) return B @I.ir_module @@ -594,7 +594,7 @@ def main(A: R.Tensor([4096], "float32")): def test_legalize_view_with_multiple_updated_fields(): - """R.view may update more than one field in the view + """R.memory.view may update more than one field in the view In this test case, a 4-kilobyte buffer is provided. The first 2-kilobytes of the buffer are used as a 1-d array of 512 int32. 
@@ -607,12 +607,12 @@ def test_legalize_view_with_multiple_updated_fields(): class Before: @R.function def main(A: R.Tensor([4096], "uint8")): - B = R.view( + B = R.memory.view( A, shape=R.shape([512]), dtype=R.dtype("int32"), ) - C = R.view( + C = R.memory.view( A, shape=R.shape([16, 64]), dtype=R.dtype("float16"), @@ -660,7 +660,7 @@ def test_execute_no_op_view(target, dev): class Module: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A) + B = R.memory.view(A) return B built = tvm.relax.build(Module, target=target) @@ -680,7 +680,7 @@ def test_execute_view_with_new_shape(target, dev): class Module: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A, shape=R.shape([64, 64])) + B = R.memory.view(A, shape=R.shape([64, 64])) return B built = tvm.relax.build(Module, target=target) @@ -700,7 +700,7 @@ def test_execute_view_with_new_byte_offset(target, dev): class Module: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view( + B = R.memory.view( A, shape=R.shape([16, 64]), relative_byte_offset=32 * 64 * 4, @@ -724,7 +724,7 @@ def test_execute_view_with_new_dtype(target, dev): class Module: @R.function def main(A: R.Tensor([4096], "float32")): - B = R.view(A, dtype="uint32") + B = R.memory.view(A, dtype="uint32") return B built = tvm.relax.build(Module, target=target) @@ -744,12 +744,12 @@ def test_execute_view_with_multiple_updated_fields(target, dev): class Module: @R.function def main(A: R.Tensor([4096], "uint8")): - B = R.view( + B = R.memory.view( A, shape=R.shape([512]), dtype=R.dtype("int32"), ) - C = R.view( + C = R.memory.view( A, shape=R.shape([16, 64]), dtype=R.dtype("float16"), From 3da5e8600b7aece7f6b5379aabe66a93073ee7d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 8 May 2024 11:02:47 -0500 Subject: [PATCH 3/3] Updates based on review comments --- python/tvm/relax/op/memory/view.py | 21 +++++++++++++++++++-- python/tvm/script/parser/relax/entry.py | 6 +++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index cf7e14f53f49..0c3d8a03b2dd 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -15,7 +15,16 @@ # specific language governing permissions and limitations # under the License. -"""Operations that act on the DLTensor container """ +"""Operations that act on the DLTensor container + +While most operations require inspecting the values stored within the +allocated buffers, some operations only require updating the fields in +a `DLTensor`, without touching the values that are stored within it. +For example, given an array of shape `[16,16]`, the slice at +`[0:8,0:16]` can be generated by changing the `DLTensor::shape` field, +while keeping the same underlying data. + +""" from typing import Optional, Sequence, Union from tvm.tir import PrimExpr @@ -33,7 +42,15 @@ def view( dtype: Optional[Expr] = None, relative_byte_offset: Optional[Expr] = None, ) -> Expr: - """Broadcasts a tensor to a specified shape. + """Provide a view into an existing tensor + + The view may have a different shape, may be a different datatype, + and may start at an offset relative to the source array. + + Regardless of which combination of these options are used, the + view may never access memory that was not accessible through the + input `data` array. This restriction applies even if the `data` + array is itself a view into a shared backing array. 
Parameters ---------- diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 7ede3d14fda7..73a5d7149a81 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -299,7 +299,11 @@ class CallableProxy(StructInfoProxy): Whether the callable is pure. derive_func: Optional[Union[str, tvm.ir.EnvFunc]] - The derivation function for the outputq + The derivation function to determine the output StructInfo, + based on the arguments provided to the function. The + specified function should be accessible using + `tvm.get_global_func`, and should have a signature + `Callable[[relax.Call, relax.BlockBuilder], relax.StructInfo]`. """
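As a point of reference, the following is a minimal sketch of a derivation function with the signature described above, registered as a global function so that `tvm.get_global_func` can locate it. The registry name, the function name, and the behavior (reusing the struct info of the first argument) are illustrative assumptions, not anything defined by this patch.

    # Minimal sketch of a derivation function matching the signature
    # Callable[[relax.Call, relax.BlockBuilder], relax.StructInfo].
    # The registry name "example.sinfo.same_as_first_arg" and the behavior
    # (reuse the struct info of the first argument) are assumptions.
    import tvm
    from tvm import relax


    @tvm.register_func("example.sinfo.same_as_first_arg")
    def _derive_sinfo(call: relax.Call, bb: relax.BlockBuilder) -> relax.StructInfo:
        # Propagate the struct info of the call's first argument to its result.
        return call.args[0].struct_info


    # Because the function is registered globally, it can be retrieved with
    # tvm.get_global_func, or referenced by name through the `derive_func`
    # argument documented above.
    derive = tvm.get_global_func("example.sinfo.same_as_first_arg")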
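Since the patch itself only exercises `R.memory.view` in unit tests, a condensed usage sketch may also help; it follows the pattern of `test_execute_view_with_multiple_updated_fields` above. The module name and the 2048-byte `relative_byte_offset` for the second view are assumptions, as that argument is not visible in the quoted diff context.

    # Sketch: one 4096-byte allocation viewed as 512 int32 values and as a
    # [16, 64] float16 array starting at an assumed 2048-byte offset.
    from tvm.script import ir as I
    from tvm.script import relax as R


    @I.ir_module
    class TwoViews:
        @R.function
        def main(A: R.Tensor([4096], "uint8")):
            # First 2048 bytes reinterpreted as 512 int32 values.
            B = R.memory.view(A, shape=R.shape([512]), dtype=R.dtype("int32"))
            # Remaining bytes reinterpreted as a [16, 64] float16 array.
            C = R.memory.view(
                A,
                shape=R.shape([16, 64]),
                dtype=R.dtype("float16"),
                relative_byte_offset=R.prim_value(2048),
            )
            return (B, C)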