diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index d452af69d39f..1a7dcd6a648b 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -21,7 +21,7 @@ from tvm import tir from tvm.ir import IRModule -from ... import expr as rx +from .... import relax as rx from ...block_builder import BlockBuilder from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo from . import core, extern @@ -136,6 +136,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: outputs, inputs = _emit_method(self.builder, method_spec, params, effects) self.builder.emit_func_output(outputs, inputs) mod = self.builder.finalize() + assert rx.analysis.well_formed(mod) + return mod, params, ext_mods diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 29b9c7fcca48..0dd8fe92bd36 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -19,7 +19,7 @@ from typing import List, Optional, Sequence, Union from tvm import relax as rx -from tvm import tir +from tvm import tir, ir from . import op from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype @@ -517,8 +517,13 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg return [ bb.emit( rx.Call( - rx.extern("vm.builtin.attention_kv_cache_create"), - args=[rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0)], + ir.Op.get("relax.call_pure_packed"), + args=[ + rx.extern("vm.builtin.attention_kv_cache_create"), + rx.op.zeros(init_shape, self.dtype), + init_shape, + rx.PrimValue(0), + ], sinfo_args=[rx.ObjectStructInfo()], ), name_hint=name_hint, @@ -588,8 +593,12 @@ def view(self, seq_len: tir.Var) -> Tensor: return Tensor( _expr=rx.BlockBuilder.current().emit( rx.Call( - rx.extern("vm.builtin.attention_kv_cache_view"), - args=[self.cache, shape], + ir.Op.get("relax.call_pure_packed"), + args=[ + rx.extern("vm.builtin.attention_kv_cache_view"), + self.cache, + shape, + ], sinfo_args=[rx.TensorStructInfo(shape, self.dtype)], ) ) @@ -611,8 +620,12 @@ def append(self, new_element: Tensor) -> None: ) self.cache = rx.BlockBuilder.current().emit( rx.Call( - rx.extern("vm.builtin.attention_kv_cache_append"), - args=[self.cache, new_element._expr], + ir.Op.get("relax.call_pure_packed"), + args=[ + rx.extern("vm.builtin.attention_kv_cache_append"), + self.cache, + new_element._expr, + ], sinfo_args=[rx.ObjectStructInfo()], ) ) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index f438f387056c..931272f4b11f 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -451,15 +451,15 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object): lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros( R.shape([8, 2, 4]), dtype="float32" ) - cache: R.Object = R.call_packed( + cache: R.Object = R.call_pure_packed( "vm.builtin.attention_kv_cache_create", lv, R.shape([8, 2, 4]), R.prim_value(0), sinfo_args=(R.Object,), ) - lv1: R.Tuple(R.Object, R.Object) = _io, cache - gv: R.Tuple(R.Object, R.Object) = lv1 + lv1 = _io, cache + gv = lv1 R.output(gv) return gv @@ -469,10 +469,10 @@ def forward( ) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)): R.func_attr({"num_input": 3}) with R.dataflow(): - lv2: R.Object = R.call_packed( + lv2: R.Object = R.call_pure_packed( "vm.builtin.attention_kv_cache_append", cache, x, sinfo_args=(R.Object,) ) - lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_packed( + lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", lv2, R.shape([4, 2, 4]),