From 4e7b8b540c564adbdf24a9d919ef2b37e78da51f Mon Sep 17 00:00:00 2001 From: ubospica Date: Sat, 25 Mar 2023 15:39:58 +0000 Subject: [PATCH 1/2] finished --- python/tvm/script/parser/relax/entry.py | 33 +++++++++++---------- tests/python/relax/test_tvmscript_parser.py | 5 +++- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 7e51264cb37c..d9a5d64c6ce2 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -32,6 +32,7 @@ TensorStructInfo, TupleStructInfo, ) +from tvm.relax.expr import Var from tvm.runtime import ObjectGeneric from tvm.tir import PrimExpr @@ -89,15 +90,21 @@ def __init__( dtype: Optional[str] = None, ndim: int = -1, ) -> None: + if isinstance(shape, Expr): + if isinstance(shape, Var): + if not isinstance(shape.struct_info, ShapeStructInfo): + raise ValueError( + "When the shape is a Var, it must have shape struct_info. But got " + f"{shape} with struct_info: {shape.struct_info}" + ) + elif not isinstance(shape, ShapeExpr): + raise ValueError( + "When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr " + f"value. But got: {shape} with type: {type(shape)}" + ) self.shape = shape - if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr): - raise ValueError( - "Only ShapeExpr is allowed as shape expr, but got: " - f"{shape} with type: {type(shape)}" - ) self.dtype = dtype self.ndim = ndim - super().__init__() def get_symbolic_vars(self) -> Set[str]: if self.shape is None or isinstance(self.shape, Expr): @@ -108,7 +115,7 @@ def get_symbolic_vars(self) -> Set[str]: def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: if self.shape is None: return TensorStructInfo(None, self.dtype, self.ndim) - elif isinstance(self.shape, ShapeExpr): + elif isinstance(self.shape, (ShapeExpr, Var)): return TensorStructInfo(self.shape, self.dtype, self.ndim) else: if dict_globals is None and any([isinstance(s, str) for s in self.shape]): @@ -121,23 +128,19 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso def Tensor( - shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> TensorProxy: # scalar tensor case - if shape is not None and len(shape) == 0: + if shape is not None and not isinstance(shape, Var) and len(shape) == 0: shape = [] if isinstance(shape, str) and dtype is None: dtype = shape shape = None - if ( - shape is not None - and not isinstance(shape, (tuple, list)) - and not isinstance(shape, ShapeExpr) - ): - raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}") + if shape is not None and not isinstance(shape, (tuple, list)) and not isinstance(shape, Expr): + raise ValueError(f"shape must be a list/tuple or an Expr, but got: {shape}") return TensorProxy(shape, dtype, ndim) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index e8c4c9de2aa3..9c8084e9d062 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -747,6 +747,7 @@ def foo( q: R.Tensor(ndim=2) = R.add(w, w) t = R.add(w, z) sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape) + lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh) o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object) return o @@ -759,13 +760,15 @@ def _check_struct_info(binding, expected_sinfo): assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo) m = relax.get_shape_of(foo.params[0])[1] bindings = foo.body.blocks[0].bindings + sh = bindings[4].var _check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32")) _check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1)) _check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2)) _check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1)) _check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1)) - _check_struct_info(bindings[5], relax.ObjectStructInfo()) + _check_struct_info(bindings[5], relax.TensorStructInfo(sh)) + _check_struct_info(bindings[6], relax.ObjectStructInfo()) def test_annotate_override(): From e619eba9269ddd0c399772d7973904022ba7591c Mon Sep 17 00:00:00 2001 From: ubospica Date: Sat, 25 Mar 2023 16:08:04 +0000 Subject: [PATCH 2/2] update --- python/tvm/script/parser/relax/entry.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index d9a5d64c6ce2..acb490a813b8 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -91,17 +91,16 @@ def __init__( ndim: int = -1, ) -> None: if isinstance(shape, Expr): - if isinstance(shape, Var): - if not isinstance(shape.struct_info, ShapeStructInfo): - raise ValueError( - "When the shape is a Var, it must have shape struct_info. But got " - f"{shape} with struct_info: {shape.struct_info}" - ) - elif not isinstance(shape, ShapeExpr): + if not isinstance(shape, (ShapeExpr, Var)): raise ValueError( "When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr " f"value. But got: {shape} with type: {type(shape)}" ) + if isinstance(shape, Var) and not isinstance(shape.struct_info, ShapeStructInfo): + raise ValueError( + "When the shape is a Var, it must have shape struct_info. But got " + f"{shape} with struct_info: {shape.struct_info}" + ) self.shape = shape self.dtype = dtype self.ndim = ndim