Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TensorStructInfo,
TupleStructInfo,
)
from tvm.relax.expr import Var
from tvm.runtime import ObjectGeneric
from tvm.tir import PrimExpr

Expand Down Expand Up @@ -89,15 +90,20 @@ def __init__(
dtype: Optional[str] = None,
ndim: int = -1,
) -> None:
if isinstance(shape, Expr):
if not isinstance(shape, (ShapeExpr, Var)):
raise ValueError(
"When the shape is an Expr, it must be a ShapeExpr or a Var with ShapeExpr "
f"value. But got: {shape} with type: {type(shape)}"
)
if isinstance(shape, Var) and not isinstance(shape.struct_info, ShapeStructInfo):
raise ValueError(
"When the shape is a Var, it must have shape struct_info. But got "
f"{shape} with struct_info: {shape.struct_info}"
)
self.shape = shape
if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
raise ValueError(
"Only ShapeExpr is allowed as shape expr, but got: "
f"{shape} with type: {type(shape)}"
)
self.dtype = dtype
self.ndim = ndim
super().__init__()

def get_symbolic_vars(self) -> Set[str]:
if self.shape is None or isinstance(self.shape, Expr):
Expand All @@ -108,7 +114,7 @@ def get_symbolic_vars(self) -> Set[str]:
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
if self.shape is None:
return TensorStructInfo(None, self.dtype, self.ndim)
elif isinstance(self.shape, ShapeExpr):
elif isinstance(self.shape, (ShapeExpr, Var)):
return TensorStructInfo(self.shape, self.dtype, self.ndim)
else:
if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
Expand All @@ -121,23 +127,19 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso


def Tensor(
shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
# scalar tensor case
if shape is not None and len(shape) == 0:
if shape is not None and not isinstance(shape, Var) and len(shape) == 0:
shape = []
if isinstance(shape, str) and dtype is None:
dtype = shape
shape = None

if (
shape is not None
and not isinstance(shape, (tuple, list))
and not isinstance(shape, ShapeExpr)
):
raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
if shape is not None and not isinstance(shape, (tuple, list)) and not isinstance(shape, Expr):
raise ValueError(f"shape must be a list/tuple or an Expr, but got: {shape}")
return TensorProxy(shape, dtype, ndim)


Expand Down
5 changes: 4 additions & 1 deletion tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def foo(
q: R.Tensor(ndim=2) = R.add(w, w)
t = R.add(w, z)
sh: R.Shape = R.call_packed("shape_of", x, sinfo_args=R.Shape)
lv: R.Tensor(sh, dtype="float32") = R.reshape(x, sh)
o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, sinfo_args=R.Object)
return o

Expand All @@ -759,13 +760,15 @@ def _check_struct_info(binding, expected_sinfo):
assert isinstance(foo.ret_struct_info, relax.ObjectStructInfo)
m = relax.get_shape_of(foo.params[0])[1]
bindings = foo.body.blocks[0].bindings
sh = bindings[4].var

_check_struct_info(bindings[0], relax.TensorStructInfo([32, m], "float32"))
_check_struct_info(bindings[1], relax.TensorStructInfo(dtype="", ndim=-1))
_check_struct_info(bindings[2], relax.TensorStructInfo(dtype="", ndim=2))
_check_struct_info(bindings[3], relax.TensorStructInfo(dtype="", ndim=-1))
_check_struct_info(bindings[4], relax.ShapeStructInfo(ndim=-1))
_check_struct_info(bindings[5], relax.ObjectStructInfo())
_check_struct_info(bindings[5], relax.TensorStructInfo(sh))
_check_struct_info(bindings[6], relax.ObjectStructInfo())


def test_annotate_override():
Expand Down