-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relax] Implement operators to read runtime DLTensor* information #16563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -244,6 +244,192 @@ def __getitem__(self, index: int) -> "ExprWithOp": | |
| raise IndexError from err | ||
| raise | ||
|
|
||
| def _check_for_tensor_struct_info(self): | ||
| """Raise an error if this is something other than a Tensor | ||
|
|
||
| Used for early checks in `expr.dtype` and `expr.shape` | ||
| accessors. While invalid usage would cause errors to be | ||
| raised during shape inference, an earlier check makes it | ||
| easier to find the invalid usage. | ||
| """ | ||
| if self.struct_info_ is None: | ||
| return | ||
|
|
||
| if not isinstance(self.struct_info_, tvm.relax.TensorStructInfo): | ||
| raise TypeError( | ||
| f"Runtime unpacking of DLDataType is only implemented for tensors, " | ||
| f"but was applied to object {self} of type {type(self)}." | ||
| ) | ||
|
|
||
| @property | ||
| def dtype(self) -> "_DLTensorDTypeProxy": | ||
| """Returns a proxy object for accessing DLTensor::dtype""" | ||
| self._check_for_tensor_struct_info() | ||
| return _DLTensorDTypeProxy(self) | ||
|
|
||
| @property | ||
| def ndim(self) -> "Expr": | ||
| """Returns the runtime value of DLTensor::ndim""" | ||
| self._check_for_tensor_struct_info() | ||
| op = tvm.ir.Op.get("relax.inspect.tensor_ndim") | ||
| return tvm.relax.Call(op, [self]) | ||
|
|
||
| @property | ||
| def shape(self) -> "_DLTensorShapeProxy": | ||
| """Returns a proxy object for accessing DLTensor::shape""" | ||
| self._check_for_tensor_struct_info() | ||
| return _DLTensorShapeProxy(self) | ||
|
|
||
|
|
||
| class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric): | ||
| """A proxy object for unpacking DLDatatype from DLTensor | ||
|
|
||
| Exposes accessors for `DLDataType` fields `type_code`, `lanes`, | ||
| and `bits` within a `DLTensor::dtype`. Accessing these fields | ||
|
Comment on lines
+287
to
+288
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are good to have. Offset might also be useful to add, as it might help for memory reuse.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. At the moment, I stuck with the values that have a direct presence elsewhere in Relax, but I agree that it would be good to be able to extract any of the |
||
| will produce `relax.Call` expressions, representing the field's | ||
| runtime value. If the datatype of the tensor is known at | ||
| compile-time, the `relax.Call` will be normalized into a | ||
| `relax.PrimValue`, with no runtime cost. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| tensor: relax.Expr | ||
|
|
||
| The relax tensor (or a variable referring to a relax tensor), | ||
| whose runtime shape is being inspected. | ||
|
|
||
| """ | ||
|
|
||
| def __init__(self, tensor): | ||
| self.tensor = tensor | ||
|
|
||
| def asobject(self): | ||
| """Provide expected in error message | ||
|
|
||
| This method is called when `_DLTensorDTypeProxy` is used in a | ||
| context that requires a `relax.Expr`. This usage is not | ||
| supported, and raising an error here can provide suggested | ||
| fixes that are not present in the default error message from | ||
| `tvm.runtime.convert_to_object`. | ||
| """ | ||
|
|
||
| fields = [f"{self.tensor}.dtype.{field}" for field in ["type_code", "bits", "lanes"]] | ||
| raise TypeError( | ||
| f"{self.tensor}.dtype cannot be converted to a relax expression, " | ||
| f"and should be used as a proxy object to access " | ||
| f"fields {fields}" | ||
| ) | ||
|
|
||
| @property | ||
| def type_code(self) -> Expr: | ||
| """Accessor for the DLDataType::bits field | ||
|
|
||
| Returns | ||
| ------- | ||
| type_code: Expr | ||
|
|
||
| The type code of the DLTensor. See the `DLDeviceType` | ||
| enum in `dlpack.h` for more information. | ||
| """ | ||
| op = tvm.ir.Op.get("relax.inspect.tensor_dtype_code") | ||
| return tvm.relax.Call(op, [self.tensor]) | ||
|
|
||
| @property | ||
| def lanes(self) -> Expr: | ||
| """Accessor for the DLDataType::bits field | ||
|
|
||
| Returns | ||
| ------- | ||
| lanes: Expr | ||
|
|
||
| The number of lanes in the DLDataType | ||
| """ | ||
| op = tvm.ir.Op.get("relax.inspect.tensor_dtype_lanes") | ||
| return tvm.relax.Call(op, [self.tensor]) | ||
|
|
||
| @property | ||
| def bits(self) -> Expr: | ||
| """Accessor for the DLDataType::bits field | ||
|
|
||
| Returns | ||
| ------- | ||
| bits: Expr | ||
|
|
||
| The number of bits in the DLDataType | ||
| """ | ||
| op = tvm.ir.Op.get("relax.inspect.tensor_dtype_bits") | ||
| return tvm.relax.Call(op, [self.tensor]) | ||
|
|
||
|
|
||
| class _DLTensorShapeProxy(tvm.runtime.ObjectGeneric): | ||
| """A proxy object for unpacking the shape from DLTensor | ||
|
|
||
| Exposes accessors for the `DLTensor::shape` field. Accessing | ||
| these fields will produce `relax.Call` expressions, representing | ||
| the field's runtime value. If the datatype of the tensor is known | ||
| at compile-time, the `relax.Call` will be normalized into a | ||
| `relax.PrimValue`, with no runtime cost. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| tensor: relax.Expr | ||
|
|
||
| The relax tensor (or a variable referring to a relax tensor), | ||
| whose runtime shape is being inspected. | ||
| """ | ||
|
|
||
| def __init__(self, tensor): | ||
| self.tensor = tensor | ||
|
|
||
| def asobject(self): | ||
| """Provide expected in error message | ||
|
|
||
| This method is called when `_DLTensorShapeProxy` is used in a | ||
| context that requires a `relax.Expr`. This usage is not | ||
| supported, and raising an error here can provide suggested | ||
| fixes that are not present in the default error message from | ||
| `tvm.runtime.convert_to_object`. | ||
| """ | ||
| raise TypeError( | ||
| f"{self.tensor}.shape cannot be converted to a relax expression, " | ||
| f"and should be used as a proxy object to access the runtime shape of the DLTensor. " | ||
| f"The DLTensor::ndim field can be accessed as len({self.tensor}), " | ||
| f"and the DLTensor::shape array can be accessed as {self.tensor}.shape[i]" | ||
| ) | ||
|
|
||
| def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr: | ||
| """Returns the extent of a tensor axis | ||
|
|
||
| Parameters | ||
| ---------- | ||
| axis: Union[int, PrimExpr, Expr] | ||
|
|
||
| The tensor axis whose extent should be returned. For ease | ||
| of use, any python integers or TIR expressions are | ||
| converted to `relax.Expr`. | ||
|
|
||
| Returns | ||
| ------- | ||
| extent: Expr | ||
|
|
||
| The extent of the tensor's axis. | ||
| """ | ||
|
|
||
| if not isinstance(axis, tvm.relax.Expr): | ||
| axis = tvm.relax.PrimValue(axis) | ||
|
|
||
| if axis.struct_info_ is not None and not isinstance( | ||
| axis.struct_info_, tvm.relax.PrimStructInfo | ||
| ): | ||
| raise TypeError( | ||
| f"The index used to access {self.tensor}.shape " | ||
| f'must have struct info R.Prim("int64"), ' | ||
| f"but index {axis} had struct info {axis.struct_info_}." | ||
| ) | ||
|
|
||
| op = tvm.ir.Op.get("relax.inspect.tensor_shape_i") | ||
| return tvm.relax.Call(op, [self.tensor, axis]) | ||
|
|
||
|
|
||
| @tvm._ffi.register_object("relax.expr.Call") | ||
| class Call(ExprWithOp): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the return type have to be in quotes? I assume it has to do with the property decorator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More the order of definitions in the file. Function annotations are evaluated when the class is being defined. Since the
_DLTensorDTypeProxyclass is defined lower in the file, the type annotations are provided as a string, rather than as a class object.