Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,192 @@ def __getitem__(self, index: int) -> "ExprWithOp":
raise IndexError from err
raise

def _check_for_tensor_struct_info(self):
"""Raise an error if this is something other than a Tensor

Used for early checks in `expr.dtype` and `expr.shape`
accessors. While invalid usage would cause errors to be
raised during shape inference, an earlier check makes it
easier to find the invalid usage.
"""
if self.struct_info_ is None:
return

if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo):
raise TypeError(
f"Runtime unpacking of DLDataType is only implemented for tensors, "
f"but was applied to object {self} of type {type(self)}."
)

@property
def dtype(self) -> "_DLTensorDTypeProxy":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the return type have to be in quotes? I assume it has to do with the property decorator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More the order of definitions in the file. Function annotations are evaluated when the class is being defined. Since the _DLTensorDTypeProxy class is defined lower in the file, the type annotations are provided as a string, rather than as a class object.

"""Returns a proxy object for accessing DLTensor::dtype"""
self._check_for_tensor_struct_info()
return _DLTensorDTypeProxy(self)

@property
def ndim(self) -> "Expr":
"""Returns the runtime value of DLTensor::ndim"""
self._check_for_tensor_struct_info()
op = tvm.ir.Op.get("relax.inspect.tensor_ndim")
return tvm.relax.Call(op, [self])

@property
def shape(self) -> "_DLTensorShapeProxy":
"""Returns a proxy object for accessing DLTensor::shape"""
self._check_for_tensor_struct_info()
return _DLTensorShapeProxy(self)


class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking DLDatatype from DLTensor

Exposes accessors for `DLDataType` fields `type_code`, `lanes`,
and `bits` within a `DLTensor::dtype`. Accessing these fields
Comment on lines +287 to +288
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are good to have. Offset might also be useful to add, as it might help for memory reuse.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. At the moment, I stuck with the values that have a direct presence elsewhere in Relax, but I agree that it would be good to be able to extract any of the DLTensor* fields.

will produce `relax.Call` expressions, representing the field's
runtime value. If the datatype of the tensor is known at
compile-time, the `relax.Call` will be normalized into a
`relax.PrimValue`, with no runtime cost.

Parameters
----------
tensor: relax.Expr

The relax tensor (or a variable referring to a relax tensor),
whose runtime shape is being inspected.

"""

def __init__(self, tensor):
self.tensor = tensor

def asobject(self):
"""Provide expected in error message

This method is called when `_DLTensorDTypeProxy` is used in a
context that requires a `relax.Expr`. This usage is not
supported, and raising an error here can provide suggested
fixes that are not present in the default error message from
`tvm.runtime.convert_to_object`.
"""

fields = [f"{self.tensor}.dtype.{field}" for field in ["type_code", "bits", "lanes"]]
raise TypeError(
f"{self.tensor}.dtype cannot be converted to a relax expression, "
f"and should be used as a proxy object to access "
f"fields {fields}"
)

@property
def type_code(self) -> Expr:
"""Accessor for the DLDataType::bits field

Returns
-------
type_code: Expr

The type code of the DLTensor. See the `DLDeviceType`
enum in `dlpack.h` for more information.
"""
op = tvm.ir.Op.get("relax.inspect.tensor_dtype_code")
return tvm.relax.Call(op, [self.tensor])

@property
def lanes(self) -> Expr:
"""Accessor for the DLDataType::bits field

Returns
-------
lanes: Expr

The number of lanes in the DLDataType
"""
op = tvm.ir.Op.get("relax.inspect.tensor_dtype_lanes")
return tvm.relax.Call(op, [self.tensor])

@property
def bits(self) -> Expr:
"""Accessor for the DLDataType::bits field

Returns
-------
bits: Expr

The number of bits in the DLDataType
"""
op = tvm.ir.Op.get("relax.inspect.tensor_dtype_bits")
return tvm.relax.Call(op, [self.tensor])


class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking the shape from DLTensor

Exposes accessors for the `DLTensor::shape` field. Accessing
these fields will produce `relax.Call` expressions, representing
the field's runtime value. If the datatype of the tensor is known
at compile-time, the `relax.Call` will be normalized into a
`relax.PrimValue`, with no runtime cost.

Parameters
----------
tensor: relax.Expr

The relax tensor (or a variable referring to a relax tensor),
whose runtime shape is being inspected.
"""

def __init__(self, tensor):
self.tensor = tensor

def asobject(self):
"""Provide expected in error message

This method is called when `_DLTensorShapeProxy` is used in a
context that requires a `relax.Expr`. This usage is not
supported, and raising an error here can provide suggested
fixes that are not present in the default error message from
`tvm.runtime.convert_to_object`.
"""
raise TypeError(
f"{self.tensor}.shape cannot be converted to a relax expression, "
f"and should be used as a proxy object to access the runtime shape of the DLTensor. "
f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
f"and the DLTensor::shape array can be accessed as {self.tensor}.shape[i]"
)

def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
"""Returns the extent of a tensor axis

Parameters
----------
axis: Union[int, PrimExpr, Expr]

The tensor axis whose extent should be returned. For ease
of use, any python integers or TIR expressions are
converted to `relax.Expr`.

Returns
-------
extent: Expr

The extent of the tensor's axis.
"""

if not isinstance(axis, tvm.relax.Expr):
axis = tvm.relax.PrimValue(axis)

if axis.struct_info_ is not None and not isinstance(
axis.struct_info_, tvm.relax.PrimStructInfo
):
raise TypeError(
f"The index used to access {self.tensor}.shape "
f'must have struct info R.Prim("int64"), '
f"but index {axis} had struct info {axis.struct_info_}."
)

op = tvm.ir.Op.get("relax.inspect.tensor_shape_i")
return tvm.relax.Call(op, [self.tensor, axis])


@tvm._ffi.register_object("relax.expr.Call")
class Call(ExprWithOp):
Expand Down
Loading