diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 43bbd461bca8..9dfb30bc7487 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -50,3 +50,21 @@ def alloc_tensor( runtime_device_index = PrimValue(runtime_device_index) return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore + + +def stop_lift_params(x: Expr) -> Expr: + """ + An indicator that the consumers of input tensor should not be + lifted to transform_params function + + Parameters + ---------- + x: relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The result tensor that is the same as input tensor + """ + return _ffi_api.stop_lift_params(x) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d3448916091d..f9104c94305a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -129,6 +129,7 @@ zeros_like, nn, ) +from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter from tvm.runtime import Object as tvm_Object @@ -634,6 +635,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "square", "squeeze", "sqrt", + "stop_lift_params", "str", "strided_slice", "subtract", diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 49df881dcb8b..b353cce27a73 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -513,5 +513,22 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); +// builtin stop_lift_params +StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnaryArith(call, ctx); +} + +RELAY_REGISTER_OP("relax.builtin.stop_lift_params") + .set_num_inputs(1) + .add_argument("x", "Expr", "The input data") + .set_attr("FInferStructInfo", InferStructInfoStopLiftParams); + +Expr MakeStopLiftParams(Expr x) { + static const Op& op = Op::Get("relax.builtin.stop_lift_params"); + return Call(op, {x}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 401a03dbe25f..88939bd1f5ea 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -168,6 +168,12 @@ class LiftTransformParamsPlanner : public ExprVisitor { if (!is_in_dataflow_block_) { can_lift = false; } + if (const auto* call = binding->value.as()) { + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + can_lift = false; + } + } PostOrderVisit(binding->value, [&](const ObjectRef& obj) { if (const VarNode* var = obj.as()) { @@ -268,6 +274,13 @@ class TransformParamsLifter : public ExprMutator { if (lift_plan_.lifted_bindings.count(binding->var)) { return; } + if (const auto* call = binding->value.as()) { + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + var_remap_[binding->var->vid] = Downcast(VisitExpr(call->args[0])); + return; + } + } ExprMutator::VisitBinding_(binding); } diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index a10a1b5fe9e9..d596c60196f3 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -103,5 +103,13 @@ def test_vm_alloc_tensor(): tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32")) +def test_builtin_stop_lift_params(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32")) + x1 = rx.op.builtin.stop_lift_params(x) + x1 = bb.normalize(x1) + tvm.ir.assert_structural_equal(x1.struct_info, R.Tensor([4, 5], "float32")) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 530866a61fa3..b6189488404f 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -19,6 +19,7 @@ import tvm.testing from tvm import relax from tvm.script import relax as R, tir as T +from tvm.script import ir as I import numpy as np import tvm.topi.testing @@ -392,5 +393,52 @@ def func3( tvm.ir.assert_structural_equal(after, Expected) +def test_stop_lifting(): + @tvm.script.ir_module + class Before: + @R.function + def func1( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((256, 256), "float32"), + ) -> R.Tensor((256, 256), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + w1_t1 = R.builtin.stop_lift_params(w1_t) + w1_add = R.add(w1_t1, R.const(1, "float32")) + y = R.matmul(x, w1_add) + R.output(y) + return y + + @I.ir_module + class Expected: + @R.function + def func1( + x: R.Tensor((256, 256), dtype="float32"), + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, R.const(1, "float32")) + y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, out_dtype="void") + R.output(y) + return y + + @R.function + def func1_transform_params( + params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0]) + gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7ce789496095..b697b5f8134e 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1196,6 +1196,16 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): _check(foo) +def test_builtin_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + tensor = R.builtin.stop_lift_params(x) + gv = tensor + return gv + + _check(foo) + + def test_prim_value(): @R.function def foo():