Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/tvm/relax/frontend/nn/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm import tir
from tvm.ir import IRModule

from ... import expr as rx
from .... import relax as rx
from ...block_builder import BlockBuilder
from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
from . import core, extern
Expand Down Expand Up @@ -136,6 +136,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
mod = self.builder.finalize()
assert rx.analysis.well_formed(mod)

return mod, params, ext_mods


Expand Down
27 changes: 20 additions & 7 deletions python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
from tvm import tir
from tvm import tir, ir

from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
Expand Down Expand Up @@ -517,8 +517,13 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg
return [
bb.emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_create"),
args=[rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0)],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_create"),
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
],
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=name_hint,
Expand Down Expand Up @@ -588,8 +593,12 @@ def view(self, seq_len: tir.Var) -> Tensor:
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_view"),
args=[self.cache, shape],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_view"),
self.cache,
shape,
],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
)
)
Expand All @@ -611,8 +620,12 @@ def append(self, new_element: Tensor) -> None:
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_append"),
args=[self.cache, new_element._expr],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_append"),
self.cache,
new_element._expr,
],
sinfo_args=[rx.ObjectStructInfo()],
)
)
Expand Down
10 changes: 5 additions & 5 deletions tests/python/relax/test_frontend_nn_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,15 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object):
lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros(
R.shape([8, 2, 4]), dtype="float32"
)
cache: R.Object = R.call_packed(
cache: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
lv,
R.shape([8, 2, 4]),
R.prim_value(0),
sinfo_args=(R.Object,),
)
lv1: R.Tuple(R.Object, R.Object) = _io, cache
gv: R.Tuple(R.Object, R.Object) = lv1
lv1 = _io, cache
gv = lv1
R.output(gv)
return gv

Expand All @@ -469,10 +469,10 @@ def forward(
) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
lv2: R.Object = R.call_packed(
lv2: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_append", cache, x, sinfo_args=(R.Object,)
)
lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_packed(
lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
lv2,
R.shape([4, 2, 4]),
Expand Down