From 84acf6830ac0869599a8f9211a1058d5e811ae50 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 4 Mar 2024 21:31:00 +0000
Subject: [PATCH 1/3] [Bugfix][SLM] Produce well-formed Relax for
 nn.modules.KVCache

Prior to this commit, the `nn.modules.KVCache` implementations used
`R.call_packed(...)` to call the `"vm.builtin.attention_*"` functions.
Since `nn.Module` emits all relax functions within a
`relax.DataflowBlock`, where impure expressions are forbidden, this is
ill-formed.

This commit updates the implementations in `nn.modules.KVCache` to use
`R.call_pure_packed` instead of `R.call_packed`. Asserting that the
callee is pure allows the call to occur within a `relax.DataflowBlock`.

---
 python/tvm/relax/frontend/nn/exporter.py     |  2 ++
 python/tvm/relax/frontend/nn/modules.py      | 27 ++++++++++++++-----
 .../python/relax/test_frontend_nn_modules.py |  9 +++----
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index d452af69d39f..443ecbad735c 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -136,6 +136,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
         outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
         self.builder.emit_func_output(outputs, inputs)
         mod = self.builder.finalize()
+        assert rx.analysis.well_formed(mod)
+
         return mod, params, ext_mods
 
 
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 29b9c7fcca48..0dd8fe92bd36 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -19,7 +19,7 @@
 from typing import List, Optional, Sequence, Union
 
 from tvm import relax as rx
-from tvm import tir
+from tvm import tir, ir
 from . import op
 from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
 
@@ -517,8 +517,13 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder):  # pylint: disable=arg
         return [
             bb.emit(
                 rx.Call(
-                    rx.extern("vm.builtin.attention_kv_cache_create"),
-                    args=[rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0)],
+                    ir.Op.get("relax.call_pure_packed"),
+                    args=[
+                        rx.extern("vm.builtin.attention_kv_cache_create"),
+                        rx.op.zeros(init_shape, self.dtype),
+                        init_shape,
+                        rx.PrimValue(0),
+                    ],
                     sinfo_args=[rx.ObjectStructInfo()],
                 ),
                 name_hint=name_hint,
@@ -588,8 +593,12 @@ def view(self, seq_len: tir.Var) -> Tensor:
         return Tensor(
             _expr=rx.BlockBuilder.current().emit(
                 rx.Call(
-                    rx.extern("vm.builtin.attention_kv_cache_view"),
-                    args=[self.cache, shape],
+                    ir.Op.get("relax.call_pure_packed"),
+                    args=[
+                        rx.extern("vm.builtin.attention_kv_cache_view"),
+                        self.cache,
+                        shape,
+                    ],
                     sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
                 )
             )
@@ -611,8 +620,12 @@ def append(self, new_element: Tensor) -> None:
             )
         self.cache = rx.BlockBuilder.current().emit(
             rx.Call(
-                rx.extern("vm.builtin.attention_kv_cache_append"),
-                args=[self.cache, new_element._expr],
+                ir.Op.get("relax.call_pure_packed"),
+                args=[
+                    rx.extern("vm.builtin.attention_kv_cache_append"),
+                    self.cache,
+                    new_element._expr,
+                ],
                 sinfo_args=[rx.ObjectStructInfo()],
             )
         )
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index f438f387056c..6e54d28da096 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -451,15 +451,14 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object):
                 lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros(
                     R.shape([8, 2, 4]), dtype="float32"
                 )
-                cache: R.Object = R.call_packed(
+                cache: R.Object = R.call_pure_packed(
                     "vm.builtin.attention_kv_cache_create",
                     lv,
                     R.shape([8, 2, 4]),
                     R.prim_value(0),
                     sinfo_args=(R.Object,),
                 )
-                lv1: R.Tuple(R.Object, R.Object) = _io, cache
-                gv: R.Tuple(R.Object, R.Object) = lv1
+                gv: R.Tuple(R.Object, R.Object) = _io, cache
                 R.output(gv)
             return gv
 
@@ -469,10 +468,10 @@ def forward(
         ) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)):
             R.func_attr({"num_input": 3})
             with R.dataflow():
-                lv2: R.Object = R.call_packed(
+                lv2: R.Object = R.call_pure_packed(
                     "vm.builtin.attention_kv_cache_append", cache, x, sinfo_args=(R.Object,)
                 )
-                lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_packed(
+                lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed(
                     "vm.builtin.attention_kv_cache_view",
                     lv2,
                     R.shape([4, 2, 4]),

From b8b7afc00c20235ea46a3ca1556107b1282f9bae Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 6 Mar 2024 18:32:58 +0000
Subject: [PATCH 2/3] Correct import for relax

---
 python/tvm/relax/frontend/nn/exporter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 443ecbad735c..1a7dcd6a648b 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -21,7 +21,7 @@
 from tvm import tir
 from tvm.ir import IRModule
 
-from ... import expr as rx
+from .... import relax as rx
 from ...block_builder import BlockBuilder
 from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
 from . import core, extern

From eaafc1eec514caa0ec8ce5b7f610d4e101da3394 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 8 Mar 2024 11:23:36 -0600
Subject: [PATCH 3/3] Fix unit test

---
 tests/python/relax/test_frontend_nn_modules.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 6e54d28da096..931272f4b11f 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -458,7 +458,8 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object):
                     R.prim_value(0),
                     sinfo_args=(R.Object,),
                 )
-                gv: R.Tuple(R.Object, R.Object) = _io, cache
+                lv1 = _io, cache
+                gv = lv1
                 R.output(gv)
             return gv
 
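
For readers unfamiliar with the purity rules involved: `R.call_packed` produces a call that Relax conservatively treats as impure, so it may not appear inside a `relax.DataflowBlock`, while `R.call_pure_packed` asserts that the callee has no visible side effects and is therefore permitted there. The standalone sketch below is illustrative only and not part of the patches; the module and function names are invented, and it assumes a TVM build with the Relax TVMScript frontend. It mirrors the expected output in the updated unit test and runs the same `well_formed` check that [PATCH 1/3] adds to `exporter.py`:

    # Illustrative sketch, not part of the patch series.
    from tvm import relax as rx
    from tvm.script import ir as I
    from tvm.script import relax as R


    @I.ir_module
    class KVCacheExample:  # hypothetical module name
        @R.function
        def main() -> R.Object:
            with R.dataflow():
                lv = R.zeros(R.shape([8, 2, 4]), dtype="float32")
                # Legal inside a dataflow block because call_pure_packed
                # asserts the callee is pure; R.call_packed here would
                # make the module ill-formed.
                cache: R.Object = R.call_pure_packed(
                    "vm.builtin.attention_kv_cache_create",
                    lv,
                    R.shape([8, 2, 4]),
                    R.prim_value(0),
                    sinfo_args=(R.Object,),
                )
                R.output(cache)
            return cache


    # The same check that [PATCH 1/3] adds to the exporter:
    assert rx.analysis.well_formed(KVCacheExample)

Note that the purity annotation is a promise from the caller, not something Relax verifies: `R.call_pure_packed` should only be used when the packed function is known to have no effects that other code can observe, which is the judgment these patches make for the `vm.builtin.attention_kv_cache_*` builtins.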