Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,21 +1108,26 @@ def inline_functions(


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
class ExternFunc(BaseFunc, ExprWithOp):
"""extern function, which represents a PackedFunc."""

global_symbol: String
span: Optional[Span]

def __init__(self, global_symbol: String, span: Optional[Span] = None) -> None:
def __init__(
self,
global_symbol: String,
struct_info: Optional[StructInfo] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ExternFunc, global_symbol, span # type: ignore
_ffi_api.ExternFunc, global_symbol, struct_info, span # type: ignore
)


def extern(name: str, span: Optional[Span] = None):
def extern(name: str, struct_info: Optional[StructInfo] = None, span: Optional[Span] = None):
"""Create extern function."""
return ExternFunc(name, span)
return ExternFunc(name, struct_info, span)


def const(
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view
94 changes: 94 additions & 0 deletions python/tvm/relax/op/memory/view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Operations that act on the DLTensor container

While most operations require inspecting the values stored within the
allocated buffers, some operations only require updating the fields in
a `DLTensor`, without touching the values that are stored within it.
For example, given an array of shape `[16,16]`, the slice at
`[0:8,0:16]` can be generated by changing the `DLTensor::shape` field,
while keeping the same underlying data.

"""
from typing import Optional, Sequence, Union

from tvm.tir import PrimExpr
from tvm.relax import Expr, ShapeExpr, DataTypeImm, PrimValue

from . import _ffi_api


PrimExprLike = Union[int, PrimExpr]


def view(
data: Expr,
shape: Optional[Union[Sequence[PrimExprLike], Expr]] = None,
dtype: Optional[Expr] = None,
relative_byte_offset: Optional[Expr] = None,
) -> Expr:
"""Provide a view into an existing tensor

The view may have a different shape, may be a different datatype,
and may start at an offset relative to the source array.

Regardless of which combination of these options are used, the
view may never access memory that was not accessible through the
input `data` array. This restriction applies even if the `data`
array is itself a view into a shared backing array.

Parameters
----------
data : relax.Expr

The input data to the operator.

shape : Optional[Union[Sequence[PrimExprLike], Expr]]

The target shape. Should be a `relax.ShapeExpr`, or a
collection that can be converted to a `relax.ShapeExpr`.

dtype : Optional[Expr]

The target datatype. Should be a `relax.ShapeExpr`, or a
collection that can be converted to a `relax.ShapeExpr`.

relative_byte_offset: Optional[Expr]

The offset of the output NDArray, relative to the byte offset
of `data`. If `None`, the offset of the view is the same as
the offset of `data`.

Returns
-------
result : relax.Expr
The tensor view

"""

def _normalize(expr, relax_cls):
if expr is None or isinstance(expr, Expr):
return expr
else:
return relax_cls(expr)

shape = _normalize(shape, ShapeExpr)
dtype = _normalize(dtype, DataTypeImm)
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore
7 changes: 5 additions & 2 deletions python/tvm/relax/struct_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(
def opaque_func(
*,
ret: Optional[StructInfo] = None,
derive_func: Optional[EnvFunc] = None,
derive_func: Optional[Union[str, EnvFunc]] = None,
purity: bool = False,
span: Span = None,
) -> "FuncStructInfo":
Expand All @@ -249,7 +249,7 @@ def opaque_func(
ret: Optional[StructInfo]
The struct info of the function return value.

derive_func: Optional[EnvFunc]
derive_func: Optional[Union[str,EnvFunc]]
The environment function used for derivation

purity: bool
Expand All @@ -266,4 +266,7 @@ def opaque_func(
----
We cannot specify ret and derive_func simultaneously.
"""

if isinstance(derive_func, str):
derive_func = tvm.ir.EnvFunc.get("tvm.relax.struct_info.infer_view_sinfo")
return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span) # type: ignore
18 changes: 16 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Callable as _Callable
from typing import Dict, List, Optional, Set, TypeVar, Union

import tvm
from tvm.relax import (
Expr,
SeqExpr,
Expand Down Expand Up @@ -277,6 +278,7 @@ class CallableProxy(StructInfoProxy):
params: List[StructInfoProxy]
ret: StructInfoProxy
purity: bool
derive_func: Optional[Union[str, tvm.ir.EnvFunc]]

"""Function type.

Expand All @@ -296,13 +298,21 @@ class CallableProxy(StructInfoProxy):
purity : bool
Whether the callable is pure.

derive_func: Optional[Union[str, tvm.ir.EnvFunc]]
The derivation function to determine the output StructInfo,
based on the arguments provided to the function. The
specified function should be accessible using
`tvm.get_global_func`, and should have a signature
`Callable[[relax.Call, relax.BlockBuilder], relax.StructInfo]`.

"""

def __init__(
self,
params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
ret: Optional[StructInfoProxy] = None,
purity: Optional[bool] = None,
derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None,
) -> None:
if params is None:
self.params = params
Expand All @@ -320,6 +330,7 @@ def __init__(

self.ret = ret() if callable(ret) else ret
self.purity = purity
self.derive_func = derive_func

def get_symbolic_vars(self) -> Set[str]:
if self.params is None:
Expand All @@ -339,7 +350,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncS
params = [param.as_struct_info(dict_globals) for param in self.params]

if params is None:
return FuncStructInfo.opaque_func(ret=ret, purity=self.purity)
return FuncStructInfo.opaque_func(
ret=ret, derive_func=self.derive_func, purity=self.purity
)
else:
return FuncStructInfo(params, ret, purity=self.purity)

Expand All @@ -348,8 +361,9 @@ def Callable(
params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
ret: Optional[StructInfoProxy] = None,
purity: Optional[bool] = None,
derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None,
) -> CallableProxy:
return CallableProxy(params, ret, purity=purity)
return CallableProxy(params, ret, purity=purity, derive_func=derive_func)


############################### R.Tuple ################################
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
struct_info = self.eval_expr(node)
return _normalize_struct_info(struct_info, var_table)
except Exception as err:
self.report_error(node, str(err))
self.report_error(node, err)
raise err


Expand Down
11 changes: 8 additions & 3 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,14 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span)
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) {
return ExternFunc(global_symbol, span);
});
TVM_REGISTER_GLOBAL("relax.ExternFunc")
.set_body_typed([](String global_symbol, Optional<StructInfo> struct_info, Span span) {
if (struct_info.defined()) {
return ExternFunc(global_symbol, struct_info.value(), span);
} else {
return ExternFunc(global_symbol, span);
}
});

Expr GetShapeOf(const Expr& expr) {
// default case, to be normalized.
Expand Down
Loading