Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
from .expr import Expr, Function, PrimValue, StringImm
from .expr import Tuple as rx_Tuple


Expand Down Expand Up @@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
1. Return the input itself if it's already a `relax.Expr`;
2. Return `relax.PrimValue` if the input is a `PrimExpr`;
3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype;
5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.

Notes
-----
1. `tvm.tir.StringImm` is not allowed because of ambiguity,
which can be either `relax.StringImm` or `relax.PrimValue`.
2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
"""
if isinstance(value, int):
return PrimValue(tir.IntImm("int64", value))
Expand All @@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
# Case 3
if isinstance(tvm_value, String):
return StringImm(value)
# Case 4 & 5
# Case 4
if isinstance(value, (tuple, list)):
# Note 2
if len(value) == 0:
return rx_Tuple([])
# Case 4
opt_prim_value = [convert_to_object(v) for v in value]
if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]):
return ShapeExpr(value)
# Case 5
# `convert_to_expr` ensures that all elements are `Expr` if no exception raises
return rx_Tuple([convert_to_expr(v) for v in value])
raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr:
return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member


############################### R.shape ################################


def shape(value: List[PrimExpr]) -> Expr:
"""Create a ShapeExpr.
Parameters
----------
value : List[PrimExpr]
The fields of the tuple.
Returns
-------
res : Expr
The result tuple.
"""
return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore


############################### PrimValue ##############################


Expand Down Expand Up @@ -407,6 +424,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"prim_value",
"print",
"reshape",
"shape",
"shape_of",
"str",
"tuple",
Expand Down
22 changes: 17 additions & 5 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tvm.relax import (
Expr,
ShapeExpr,
FuncStructInfo,
Function,
ObjectStructInfo,
Expand Down Expand Up @@ -84,24 +85,31 @@ class TensorProxy(StructInfoProxy):

def __init__(
self,
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> None:
self.shape = shape
if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
raise ValueError(
"Only ShapeExpr is allowed as shape expr, but got: "
f"{shape} with type: {type(shape)}"
)
self.dtype = dtype
self.ndim = ndim
super().__init__()

def get_symbolic_vars(self) -> Set[str]:
if self.shape is None:
if self.shape is None or isinstance(self.shape, Expr):
return {}
else:
return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}

def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
if self.shape is None:
return TensorStructInfo(None, self.dtype, self.ndim)
elif isinstance(self.shape, ShapeExpr):
return TensorStructInfo(self.shape, self.dtype, self.ndim)
else:
if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
raise ValueError(
Expand All @@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso


def Tensor(
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
Expand All @@ -124,8 +132,12 @@ def Tensor(
dtype = shape
shape = None

if shape is not None and not isinstance(shape, (tuple, list)):
raise ValueError(f"shape must be a list or tuple, but got: {shape}")
if (
shape is not None
and not isinstance(shape, (tuple, list))
and not isinstance(shape, ShapeExpr)
):
raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
return TensorProxy(shape, dtype, ndim)


Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/relax/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
for (int i = 0, l = n->values.size(); i < l; ++i) {
values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d));
}
return TupleDoc(values_doc);
return Relax(d, "shape")->Call({ListDoc(values_doc)});
});

Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) {
Expand Down
14 changes: 13 additions & 1 deletion src/script/printer/relax/struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Array<String> kwargs_keys;
Array<ExprDoc> kwargs_values;
if (n->shape.defined()) {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
// Need to dig into ShapeExpr to preserve the `R.shape` prefix
if (const auto* shape = n->shape.value().as<relax::ShapeExprNode>()) {
auto shape_expr = GetRef<relax::ShapeExpr>(shape);
ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
Array<ExprDoc> shape_docs;
for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) {
shape_docs.push_back(
PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d));
}
args.push_back(TupleDoc(shape_docs));
} else {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
}
}
if (!n->IsUnknownDtype()) {
kwargs_keys.push_back("dtype");
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_backend_transform_shape_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def main(
n = T.Var("n", "int64")
k = T.Var("k", "int64")
z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
return (k + 1, m, 2)
return R.shape([k + 1, m, 2])

# slot assignment:
# 0: n, 1: m, 2:k, 3: k+1
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class TestVMBuiltinLower:
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
m, n = T.var("int64"), T.var("int64")
alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
_ = R.call_packed(
"test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
)
Expand Down
36 changes: 26 additions & 10 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
from tvm.relax import DynTensorType
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T


def _check(
Expand Down Expand Up @@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
_check(foo, bb.get()["foo"])


def test_relax_base_op():
@R.function
def foo(x: R.Tensor((4, 4), "float32")):
alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
shape = R.shape_of(alloc)
return shape

x = relax.Var("x", R.Tensor((4, 4), "float32"))
bb = relax.BlockBuilder()
with bb.function("foo", (x,)):
alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0))
shape = bb.emit(relax.op.shape_of(alloc))
bb.emit_func_output(shape)
# todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR
# _check(foo, bb.get()["foo"])


def test_symbolic_shape():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
Expand Down Expand Up @@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
y0 = R.match_cast(y, R.Tensor([n], "float32"))
gv = y0
R.output(gv)
return (x0, (m, n * 2))
return (x0, R.shape([m, n * 2]))

x = relax.Var("x", R.Tensor("float32"))
y = relax.Var("y", R.Tensor("float32"))
Expand Down Expand Up @@ -314,7 +330,7 @@ def test_tuple_return_2():
def foo(x: R.Tensor("float32", ndim=2)):
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
return (x0, (n + 1, m, 1))
return (x0, R.shape([n + 1, m, 1]))

x = relax.Var("x", R.Tensor("float32", ndim=2))
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
Expand All @@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)):
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
t0 = (x, x0)
t1 = (x, (n, m), t0)
t1 = (x, R.shape([n, m]), t0)
return t1

x = relax.Var("x", R.Tensor("float32", ndim=2))
Expand Down Expand Up @@ -965,9 +981,9 @@ def test_vm_ops():
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
m = T.var("int64")
n = T.var("int64")
storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0)
alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32")
tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0)
storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0)
alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32")
tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0)
_ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
gv = tensor
return alloc, gv
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_tuple_get_item():

def test_shape_expr():
obj = relax.ShapeExpr([1, 2, 3])
_assert_print(obj, "(1, 2, 3)")
_assert_print(obj, "R.shape([1, 2, 3])")


def test_call():
Expand All @@ -304,7 +304,7 @@ def test_call():
"""
x = T.Var("x", "int64")
a: R.Tensor((1, x, 3), dtype="float32")
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,))
R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
""",
)

Expand Down
6 changes: 3 additions & 3 deletions tests/python/relax/test_vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TestVMCompileStage2:
def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
n, m = T.var("int64"), T.var("int64")
_ = R.match_cast(x, R.Tensor((n, m), "float32"))
return (n * 2, m * 3)
return R.shape([n * 2, m * 3])

mod = TestVMCompileStage2
target = tvm.target.Target("llvm", host="llvm")
Expand Down Expand Up @@ -511,9 +511,9 @@ class TestMemoryAllocStorageTensor:
@R.function
def main(x: R.Tensor((2, 3), dtype="float32")):
storage = R.memory.alloc_storage(
(24,), virtual_device_index=0, storage_scope="global", dtype="float32"
R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32")
_ = copy(x, y)
return y

Expand Down
14 changes: 8 additions & 6 deletions tests/python/relax/test_vm_codegen_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

Restrictions: all shape lowered, explicit allocation.
"""
import tvm
import pytest
import numpy as np
from tvm import relax, TVMError
from tvm.script import relax as R, tir as T
import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
from tvm.relax.testing.vm import check_saved_func
from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
from tvm.script import relax as R
from tvm.script import tir as T

EXEC_MODE = ["bytecode"]

Expand Down Expand Up @@ -312,7 +314,7 @@ class TestVMBuiltinReshape:
def main(x: R.Tensor((3, 4), "float32")):
R.func_attr({"global_symbol": "main"})
y = R.call_packed(
"vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
"vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32")
)
return y

Expand Down