diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 7e51264cb37c..acb490a813b8 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -32,6 +32,7 @@ TensorStructInfo, TupleStructInfo, ) +from tvm.relax.expr import Var from tvm.runtime import ObjectGeneric from tvm.tir import PrimExpr @@ -89,15 +90,20 @@ def __init__( dtype: Optional[str] = None, ndim: int = -1, ) -> None: + if isinstance(shape, Expr): + if not isinstance(shape, (ShapeExpr, Var)): + raise ValueError( + "When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr " + f"value. But got: {shape} with type: {type(shape)}" + ) + if isinstance(shape, Var) and not isinstance(shape.struct_info, ShapeStructInfo): + raise ValueError( + "When the shape is a Var, it must have shape struct_info. But got " + f"{shape} with struct_info: {shape.struct_info}" + ) self.shape = shape - if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr): - raise ValueError( - "Only ShapeExpr is allowed as shape expr, but got: " - f"{shape} with type: {type(shape)}" - ) self.dtype = dtype self.ndim = ndim - super().__init__() def get_symbolic_vars(self) -> Set[str]: if self.shape is None or isinstance(self.shape, Expr): @@ -108,7 +114,7 @@ def get_symbolic_vars(self) -> Set[str]: def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: if self.shape is None: return TensorStructInfo(None, self.dtype, self.ndim) - elif isinstance(self.shape, ShapeExpr): + elif isinstance(self.shape, (ShapeExpr, Var)): return TensorStructInfo(self.shape, self.dtype, self.ndim) else: if dict_globals is None and any([isinstance(s, str) for s in self.shape]): @@ -121,23 +127,19 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso def Tensor( - shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> TensorProxy: # scalar tensor case - if shape is not None and len(shape) == 0: + if shape is not None and not isinstance(shape, Var) and len(shape) == 0: shape = [] if isinstance(shape, str) and dtype is None: dtype = shape shape = None - if ( - shape is not None - and not isinstance(shape, (tuple, list)) - and not isinstance(shape, ShapeExpr) - ): - raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}") + if shape is not None and not isinstance(shape, (tuple, list)) and not isinstance(shape, Expr): + raise ValueError(f"shape must be a list/tuple or an Expr, but got: {shape}") return TensorProxy(shape, dtype, ndim) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index e8c4c9de2aa3..9c8084e9d062 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -747,6 +747,7 @@ def foo( q: R.Tensor(ndim=2) = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) return o @@ -759,13 +760,15 @@ def _check_struct_info(binding, expected_sinfo): assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) m = relax.get_shape_of(foo.params[0])[1] bindings = foo.body.blocks[0].bindings + sh = bindings[4].var _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) - _check_struct_info(bindings[5], relax.ObjectStructInfo()) + _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) + _check_struct_info(bindings[6], relax.ObjectStructInfo()) def test_annotate_override():