From 5f236eee23704ab3dcda94194af565f04aff6255 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Wed, 22 Mar 2023 08:12:52 -0700 Subject: [PATCH 1/7] lift param change --- include/tvm/relax/expr.h | 2 + src/relax/transform/lift_transform_params.cc | 22 +++- .../test_transform_lift_transform_params.py | 101 ++++++++++++++++++ 3 files changed, 123 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 0788193ee7c4..24c3abfbba04 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -983,6 +983,8 @@ constexpr const char* kCodegen = "Codegen"; constexpr const char* kComposite = "Composite"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +/*! \brief Indicate the function cannot be lifted to transform_params*/ +constexpr const char* kStopLifting = "StopLifting"; } // namespace attr /*! \brief The extern function, which can represent packed function. */ diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 401a03dbe25f..7039d13e18ba 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -144,6 +144,8 @@ class TransformParamsFuncBuilder : public ExprMutator { */ class LiftTransformParamsPlanner : public ExprVisitor { public: + explicit LiftTransformParamsPlanner(IRModule mod) : mod_(mod) {} + LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) { for (int i = num_inputs; i < static_cast(function->params.size()); ++i) { builder_.AddInput(function->params[i]); @@ -168,6 +170,18 @@ class LiftTransformParamsPlanner : public ExprVisitor { if (!is_in_dataflow_block_) { can_lift = false; } + if (const auto* call = binding->value.as()) { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + if (call->op.same_as(call_tir_op_)) { + if (const auto* gv = call->args[0].as()) { + if (const auto* prim_func = mod_->Lookup(GetRef(gv)).as()) { + if (prim_func->HasNonzeroAttr(attr::kStopLifting)) { + can_lift = false; + } + } + } + } + } PostOrderVisit(binding->value, [&](const ObjectRef& obj) { if (const VarNode* var = obj.as()) { @@ -195,6 +209,8 @@ class LiftTransformParamsPlanner : public ExprVisitor { TransformParamsFuncBuilder builder_; // Whether we are in a dataflow block bool is_in_dataflow_block_{false}; + // The module + IRModule mod_; }; /*! 
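The planner check added above is driven end to end by the test included later in this patch. A minimal sketch of the intended workflow at this point in the series, assuming an IRModule Before as defined in that test, whose legalized "add" PrimFunc must stay in the runtime function:

import tvm
from tvm import relax

# Legalize high-level ops into PrimFuncs, then tag the function that must
# not be precomputed at transform time with the StopLifting attribute.
mod = relax.transform.LegalizeOps()(Before)
mod["add"] = mod["add"].with_attr("StopLifting", True)

# LiftTransformParams splits func1 into func1(x, params) and
# func1_transform_params(params): the weight transpose is lifted into the
# transform function, while the tagged add, and the matmul downstream of
# it, remain in func1.
after = relax.transform.LiftTransformParams()(mod)
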
@@ -203,7 +219,7 @@ class LiftTransformParamsPlanner : public ExprVisitor { */ class TransformParamsLifter : public ExprMutator { public: - explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) {} + explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module), mod_(module) {} IRModule Lift() { auto mod = builder_->GetContextIRModule(); @@ -228,7 +244,7 @@ class TransformParamsLifter : public ExprMutator { private: Function RewriteFunc(const Function& func, int num_input, String new_func_name) { - LiftTransformParamsPlanner planner; + LiftTransformParamsPlanner planner(mod_); // Step 1: Create the plan of lifting transform params lift_plan_ = planner.Plan(func, num_input); @@ -288,6 +304,8 @@ class TransformParamsLifter : public ExprMutator { std::unordered_map param_remap_; // The plan of lifting the transform params LiftTransformParamsInfoPlan lift_plan_; + // The module + IRModule mod_; }; namespace transform { diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 530866a61fa3..c3f3d960abde 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -392,5 +392,106 @@ def func3( tvm.ir.assert_structural_equal(after, Expected) +def test_stop_lifting(): + @tvm.script.ir_module + class Before: + @R.function + def func1( + x: R.Tensor((256, 256), "float32"), + w1: R.Tensor((256, 256), "float32"), + ) -> R.Tensor((256, 256), "float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + w1_t = R.permute_dims(w1, [1, 0]) + w1_add = R.add(w1_t, R.const(1, "float32")) + y = R.matmul(x, w1_add) + R.output(y) + return y + + mod = relax.transform.LegalizeOps()(Before) + mod["add"] = mod["add"].with_attr("StopLifting", True) + after = relax.transform.LiftTransformParams()(mod) + + @tvm.script.ir_module + class Expected: + @T.prim_func + def add( + rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), + T_add: T.Buffer((T.int64(256), T.int64(256)), "float32"), + ): + T.func_attr({"StopLifting": True, "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(256), T.int64(256)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(rxplaceholder[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + T.float32(1) + + @T.prim_func + def matmul( + rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), + rxplaceholder_1: T.Buffer((T.int64(256), T.int64(256)), "float32"), + matmul_1: T.Buffer((T.int64(256), T.int64(256)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for i0, i1, k in T.grid(T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) + T.writes(matmul_1[v_i0, v_i1]) + with T.init(): + matmul_1[v_i0, v_i1] = T.float32(0) + matmul_1[v_i0, v_i1] = ( + matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] + ) + + @T.prim_func + def transpose( + rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), + T_transpose: T.Buffer((T.int64(256), T.int64(256)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(256), T.int64(256)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + 
T.reads(rxplaceholder[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] + + @R.function + def func1( + x: R.Tensor((256, 256), dtype="float32"), + params: R.Tuple(R.Tensor((256, 256), dtype="float32")), + ) -> R.Tensor((256, 256), dtype="float32"): + cls = Expected + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + w1_add = R.call_tir(cls.add, (lv,), out_sinfo=R.Tensor((256, 256), dtype="float32")) + y = R.call_tir( + cls.matmul, (x, w1_add), out_sinfo=R.Tensor((256, 256), dtype="float32") + ) + R.output(y) + return y + + @R.function + def func1_transform_params( + params: R.Tuple(R.Tensor((256, 256), dtype="float32")) + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32")): + cls = Expected + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = params[0] + lv1 = R.call_tir( + cls.transpose, (lv,), out_sinfo=R.Tensor((256, 256), dtype="float32") + ) + gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() From 0d7c835782bbb62df4964812cee5d7b111697879 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 23 Mar 2023 11:26:15 -0700 Subject: [PATCH 2/7] stop lift params op --- include/tvm/relax/expr.h | 2 - python/tvm/relax/op/unary.py | 17 +++++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/unary.cc | 3 + src/relax/op/tensor/unary.h | 4 + src/relax/transform/lift_transform_params.cc | 19 ++--- tests/python/relax/test_op_unary.py | 1 + .../test_transform_lift_transform_params.py | 74 +++---------------- .../test_tvmscript_parser_op_arith_cmp.py | 1 + 9 files changed, 49 insertions(+), 74 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 24c3abfbba04..0788193ee7c4 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -983,8 +983,6 @@ constexpr const char* kCodegen = "Codegen"; constexpr const char* kComposite = "Composite"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; -/*! \brief Indicate the function cannot be lifted to transform_params*/ -constexpr const char* kStopLifting = "StopLifting"; } // namespace attr /*! \brief The extern function, which can represent packed function. */ diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index 866d2a8273d6..270ad52b0078 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -527,3 +527,20 @@ def isnan(x: Expr) -> Expr: The computed result. 
""" return _ffi_api.isnan(x) # type: ignore + + +def stop_lift_params(x: Expr) -> Expr: + """ + An indicator that the consumers of input tensor should not be lifted to transform_params function + + Parameters + ---------- + x: relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The result tensor that is the same as input tensor + """ + return _ffi_api.stop_lift_params(x) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index d3448916091d..eb76a2e5a70d 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -115,6 +115,7 @@ square, squeeze, sqrt, + stop_lift_params, subtract, tan, tanh, @@ -634,6 +635,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "square", "squeeze", "sqrt", + "stop_lift_params", "str", "strided_slice", "subtract", diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index f1117c1826c5..5df4e8c490d7 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -87,5 +87,8 @@ RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite); RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isinf); RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isnan); +/***************** Indicator operators *****************/ +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(stop_lift_params, /*require_float_dtype=*/false); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h index 8f6404c5d9ed..9cd0687c979f 100644 --- a/src/relax/op/tensor/unary.h +++ b/src/relax/op/tensor/unary.h @@ -138,6 +138,10 @@ Expr isinf(Expr x); /*! \brief Check if input value is Nan. */ Expr isnan(Expr x); +/***************** Indicator operators *****************/ +/*! \brief An indicator that the consumers of input tensor should not be lifted to transform_params + * function*/ +Expr stop_lift_params(Expr x); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 7039d13e18ba..65ec0d4c8563 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -171,15 +171,9 @@ class LiftTransformParamsPlanner : public ExprVisitor { can_lift = false; } if (const auto* call = binding->value.as()) { - static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - if (call->op.same_as(call_tir_op_)) { - if (const auto* gv = call->args[0].as()) { - if (const auto* prim_func = mod_->Lookup(GetRef(gv)).as()) { - if (prim_func->HasNonzeroAttr(attr::kStopLifting)) { - can_lift = false; - } - } - } + static const Op& stop_lift_params_op = Op::Get("relax.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + can_lift = false; } } @@ -284,6 +278,13 @@ class TransformParamsLifter : public ExprMutator { if (lift_plan_.lifted_bindings.count(binding->var)) { return; } + if (const auto* call = binding->value.as()) { + static const Op& stop_lift_params_op = Op::Get("relax.stop_lift_params"); + if (call->op.same_as(stop_lift_params_op)) { + var_remap_[binding->var->vid] = Downcast(VisitExpr(call->args[0])); + return; + } + } ExprMutator::VisitBinding_(binding); } diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py index 45336661a1ae..ae8662cc1f17 100644 --- a/tests/python/relax/test_op_unary.py +++ b/tests/python/relax/test_op_unary.py @@ -83,6 +83,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r (relax.op.sqrt, True), 
(relax.op.tan, True), (relax.op.tanh, True), + (relax.op.stop_lift_params, False), ) diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index c3f3d960abde..ec74c1f36668 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -19,6 +19,7 @@ import tvm.testing from tvm import relax from tvm.script import relax as R, tir as T +from tvm.script import ir as I import numpy as np import tvm.topi.testing @@ -403,76 +404,23 @@ def func1( R.func_attr({"num_input": 1}) with R.dataflow(): w1_t = R.permute_dims(w1, [1, 0]) - w1_add = R.add(w1_t, R.const(1, "float32")) + w1_t1 = R.stop_lift_params(w1_t) + w1_add = R.add(w1_t1, R.const(1, "float32")) y = R.matmul(x, w1_add) R.output(y) return y - mod = relax.transform.LegalizeOps()(Before) - mod["add"] = mod["add"].with_attr("StopLifting", True) - after = relax.transform.LiftTransformParams()(mod) - - @tvm.script.ir_module + @I.ir_module class Expected: - @T.prim_func - def add( - rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), - T_add: T.Buffer((T.int64(256), T.int64(256)), "float32"), - ): - T.func_attr({"StopLifting": True, "tir.noalias": True}) - # with T.block("root"): - for ax0, ax1 in T.grid(T.int64(256), T.int64(256)): - with T.block("T_add"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax0, v_ax1]) - T.writes(T_add[v_ax0, v_ax1]) - T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + T.float32(1) - - @T.prim_func - def matmul( - rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), - rxplaceholder_1: T.Buffer((T.int64(256), T.int64(256)), "float32"), - matmul_1: T.Buffer((T.int64(256), T.int64(256)), "float32"), - ): - T.func_attr({"tir.noalias": True}) - # with T.block("root"): - for i0, i1, k in T.grid(T.int64(256), T.int64(256), T.int64(256)): - with T.block("matmul"): - v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) - T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1]) - T.writes(matmul_1[v_i0, v_i1]) - with T.init(): - matmul_1[v_i0, v_i1] = T.float32(0) - matmul_1[v_i0, v_i1] = ( - matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1] - ) - - @T.prim_func - def transpose( - rxplaceholder: T.Buffer((T.int64(256), T.int64(256)), "float32"), - T_transpose: T.Buffer((T.int64(256), T.int64(256)), "float32"), - ): - T.func_attr({"tir.noalias": True}) - # with T.block("root"): - for ax0, ax1 in T.grid(T.int64(256), T.int64(256)): - with T.block("T_transpose"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(rxplaceholder[v_ax1, v_ax0]) - T.writes(T_transpose[v_ax0, v_ax1]) - T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] - @R.function def func1( x: R.Tensor((256, 256), dtype="float32"), params: R.Tuple(R.Tensor((256, 256), dtype="float32")), ) -> R.Tensor((256, 256), dtype="float32"): - cls = Expected with R.dataflow(): lv: R.Tensor((256, 256), dtype="float32") = params[0] - w1_add = R.call_tir(cls.add, (lv,), out_sinfo=R.Tensor((256, 256), dtype="float32")) - y = R.call_tir( - cls.matmul, (x, w1_add), out_sinfo=R.Tensor((256, 256), dtype="float32") - ) + w1_add: R.Tensor((256, 256), dtype="float32") = R.add(lv, R.const(1, "float32")) + y: R.Tensor((256, 256), dtype="float32") = R.matmul(x, w1_add, out_dtype="void") R.output(y) return y @@ -480,18 +428,18 @@ def func1( def func1_transform_params( params: R.Tuple(R.Tensor((256, 256), dtype="float32")) ) -> 
R.Tuple(R.Tensor((256, 256), dtype="float32")): - cls = Expected with R.dataflow(): lv: R.Tensor((256, 256), dtype="float32") = params[0] - lv1 = R.call_tir( - cls.transpose, (lv,), out_sinfo=R.Tensor((256, 256), dtype="float32") - ) + lv1: R.Tensor((256, 256), dtype="float32") = R.permute_dims(lv, axes=[1, 0]) gv: R.Tuple(R.Tensor((256, 256), dtype="float32")) = (lv1,) R.output(gv) return gv + mod = Before + after = relax.transform.LiftTransformParams()(mod) tvm.ir.assert_structural_equal(after, Expected) if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_stop_lifting() diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py index d43e9a626b66..a042e2c2cb57 100644 --- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py +++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py @@ -59,6 +59,7 @@ def _check( (relax.op.sqrt,), (relax.op.tan,), (relax.op.tanh,), + (relax.op.stop_lift_params,), ) From 6a081b10b0a868619039413d5b85fed8dbdcc41a Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 23 Mar 2023 11:27:24 -0700 Subject: [PATCH 3/7] fix --- tests/python/relax/test_transform_lift_transform_params.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index ec74c1f36668..9c9172bc34c3 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -441,5 +441,4 @@ def func1_transform_params( if __name__ == "__main__": - # tvm.testing.main() - test_stop_lifting() + tvm.testing.main() From 5f0931f9e6de86eb44e7c333cc83f552268ea235 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 23 Mar 2023 11:29:27 -0700 Subject: [PATCH 4/7] fix --- python/tvm/relax/op/unary.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index 270ad52b0078..6435bf26c327 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -529,6 +529,9 @@ def isnan(x: Expr) -> Expr: return _ffi_api.isnan(x) # type: ignore +###################### Indicator operators ###################### + + def stop_lift_params(x: Expr) -> Expr: """ An indicator that the consumers of input tensor should not be lifted to transform_params function From ff8f17cb5a16b8ee4ab5de00c1839767dabee546 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Fri, 24 Mar 2023 14:00:52 -0700 Subject: [PATCH 5/7] move to builtin --- python/tvm/relax/op/builtin/builtin.py | 17 ++++++++++++++++ python/tvm/relax/op/unary.py | 20 ------------------- python/tvm/script/ir_builder/relax/ir.py | 2 +- src/relax/op/op.cc | 17 ++++++++++++++++ src/relax/op/tensor/unary.cc | 3 --- src/relax/op/tensor/unary.h | 4 ---- src/relax/transform/lift_transform_params.cc | 4 ++-- tests/python/relax/test_op_misc.py | 8 ++++++++ tests/python/relax/test_op_unary.py | 1 - .../test_transform_lift_transform_params.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 10 ++++++++++ .../test_tvmscript_parser_op_arith_cmp.py | 1 - 12 files changed, 56 insertions(+), 33 deletions(-) diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 43bbd461bca8..8b181155804a 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -50,3 +50,20 @@ def alloc_tensor( 
runtime_device_index = PrimValue(runtime_device_index) return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore + + +def stop_lift_params(x: Expr) -> Expr: + """ + An indicator that the consumers of input tensor should not be lifted to transform_params function + + Parameters + ---------- + x: relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The result tensor that is the same as input tensor + """ + return _ffi_api.stop_lift_params(x) # type: ignore diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index 6435bf26c327..866d2a8273d6 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -527,23 +527,3 @@ def isnan(x: Expr) -> Expr: The computed result. """ return _ffi_api.isnan(x) # type: ignore - - -###################### Indicator operators ###################### - - -def stop_lift_params(x: Expr) -> Expr: - """ - An indicator that the consumers of input tensor should not be lifted to transform_params function - - Parameters - ---------- - x: relax.Expr - The input data - - Returns - ------- - result : relax.Expr - The result tensor that is the same as input tensor - """ - return _ffi_api.stop_lift_params(x) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index eb76a2e5a70d..f9104c94305a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -115,7 +115,6 @@ square, squeeze, sqrt, - stop_lift_params, subtract, tan, tanh, @@ -130,6 +129,7 @@ zeros_like, nn, ) +from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter from tvm.runtime import Object as tvm_Object diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 49df881dcb8b..b353cce27a73 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -513,5 +513,22 @@ Expr MakeCallTIRDyn(Expr func, Tuple args) { TVM_REGISTER_GLOBAL("relax.op.vm.call_tir_dyn").set_body_typed(MakeCallTIRDyn); +// builtin stop_lift_params +StructInfo InferStructInfoStopLiftParams(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnaryArith(call, ctx); +} + +RELAY_REGISTER_OP("relax.builtin.stop_lift_params") + .set_num_inputs(1) + .add_argument("x", "Expr", "The input data") + .set_attr("FInferStructInfo", InferStructInfoStopLiftParams); + +Expr MakeStopLiftParams(Expr x) { + static const Op& op = Op::Get("relax.builtin.stop_lift_params"); + return Call(op, {x}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 5df4e8c490d7..f1117c1826c5 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -87,8 +87,5 @@ RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite); RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isinf); RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isnan); -/***************** Indicator operators *****************/ -RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(stop_lift_params, /*require_float_dtype=*/false); - } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h index 9cd0687c979f..8f6404c5d9ed 100644 --- a/src/relax/op/tensor/unary.h +++ b/src/relax/op/tensor/unary.h @@ -138,10 +138,6 @@ Expr isinf(Expr x); /*! \brief Check if input value is Nan. 
*/ Expr isnan(Expr x); -/***************** Indicator operators *****************/ -/*! \brief An indicator that the consumers of input tensor should not be lifted to transform_params - * function*/ -Expr stop_lift_params(Expr x); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 65ec0d4c8563..88f04e4eac0e 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -171,7 +171,7 @@ class LiftTransformParamsPlanner : public ExprVisitor { can_lift = false; } if (const auto* call = binding->value.as()) { - static const Op& stop_lift_params_op = Op::Get("relax.stop_lift_params"); + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); if (call->op.same_as(stop_lift_params_op)) { can_lift = false; } @@ -279,7 +279,7 @@ class TransformParamsLifter : public ExprMutator { return; } if (const auto* call = binding->value.as()) { - static const Op& stop_lift_params_op = Op::Get("relax.stop_lift_params"); + static const Op& stop_lift_params_op = Op::Get("relax.builtin.stop_lift_params"); if (call->op.same_as(stop_lift_params_op)) { var_remap_[binding->var->vid] = Downcast(VisitExpr(call->args[0])); return; diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index a10a1b5fe9e9..d596c60196f3 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -103,5 +103,13 @@ def test_vm_alloc_tensor(): tvm.ir.assert_structural_equal(alloc.struct_info, R.Tensor([4, 5], "float32")) +def test_builtin_stop_lift_params(): + bb = rx.BlockBuilder() + x = rx.Var("x", rx.TensorStructInfo(shape=[4, 5], dtype="float32")) + x1 = rx.op.builtin.stop_lift_params(x) + x1 = bb.normalize(x1) + tvm.ir.assert_structural_equal(x1.struct_info, R.Tensor([4, 5], "float32")) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py index ae8662cc1f17..45336661a1ae 100644 --- a/tests/python/relax/test_op_unary.py +++ b/tests/python/relax/test_op_unary.py @@ -83,7 +83,6 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r (relax.op.sqrt, True), (relax.op.tan, True), (relax.op.tanh, True), - (relax.op.stop_lift_params, False), ) diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 9c9172bc34c3..b6189488404f 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -404,7 +404,7 @@ def func1( R.func_attr({"num_input": 1}) with R.dataflow(): w1_t = R.permute_dims(w1, [1, 0]) - w1_t1 = R.stop_lift_params(w1_t) + w1_t1 = R.builtin.stop_lift_params(w1_t) w1_add = R.add(w1_t1, R.const(1, "float32")) y = R.matmul(x, w1_add) R.output(y) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7ce789496095..b697b5f8134e 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1196,6 +1196,16 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): _check(foo) +def test_builtin_ops(): + @R.function + def foo(x: R.Tensor(("m", "n"), dtype="float32")): + tensor = R.builtin.stop_lift_params(x) + gv = tensor + return gv + + _check(foo) + + def test_prim_value(): @R.function def foo(): diff --git 
a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
index a042e2c2cb57..d43e9a626b66 100644
--- a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
+++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py
@@ -59,7 +59,6 @@ def _check(
     (relax.op.sqrt,),
     (relax.op.tan,),
     (relax.op.tanh,),
-    (relax.op.stop_lift_params,),
 )


From b400e6d90031c8863c9850a60aebff297ac9f794 Mon Sep 17 00:00:00 2001
From: jinhongyi <3231950289@qq.com>
Date: Fri, 24 Mar 2023 16:02:49 -0700
Subject: [PATCH 6/7] remove mod_

---
 src/relax/transform/lift_transform_params.cc | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 88f04e4eac0e..88939bd1f5ea 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -144,8 +144,6 @@ class TransformParamsFuncBuilder : public ExprMutator {
  */
 class LiftTransformParamsPlanner : public ExprVisitor {
  public:
-  explicit LiftTransformParamsPlanner(IRModule mod) : mod_(mod) {}
-
   LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
     for (int i = num_inputs; i < static_cast<int>(function->params.size()); ++i) {
       builder_.AddInput(function->params[i]);
@@ -203,8 +201,6 @@ class LiftTransformParamsPlanner : public ExprVisitor {
   TransformParamsFuncBuilder builder_;
   // Whether we are in a dataflow block
   bool is_in_dataflow_block_{false};
-  // The module
-  IRModule mod_;
 };

 /*!
@@ -213,7 +209,7 @@ class LiftTransformParamsPlanner : public ExprVisitor {
  */
 class TransformParamsLifter : public ExprMutator {
  public:
-  explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module), mod_(module) {}
+  explicit TransformParamsLifter(const IRModule& module) : ExprMutator(module) {}

   IRModule Lift() {
     auto mod = builder_->GetContextIRModule();
@@ -238,7 +234,7 @@ class TransformParamsLifter : public ExprMutator {

  private:
   Function RewriteFunc(const Function& func, int num_input, String new_func_name) {
-    LiftTransformParamsPlanner planner(mod_);
+    LiftTransformParamsPlanner planner;

     // Step 1: Create the plan of lifting transform params
     lift_plan_ = planner.Plan(func, num_input);
@@ -305,8 +301,6 @@ class TransformParamsLifter : public ExprMutator {
   std::unordered_map param_remap_;
   // The plan of lifting the transform params
   LiftTransformParamsInfoPlan lift_plan_;
-  // The module
-  IRModule mod_;
 };

 namespace transform {

From 712b5dfc2b1f0549016c44767cabc40adca5f07a Mon Sep 17 00:00:00 2001
From: jinhongyi <3231950289@qq.com>
Date: Fri, 24 Mar 2023 17:18:31 -0700
Subject: [PATCH 7/7] fix lint

---
 python/tvm/relax/op/builtin/builtin.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py
index 8b181155804a..9dfb30bc7487 100644
--- a/python/tvm/relax/op/builtin/builtin.py
+++ b/python/tvm/relax/op/builtin/builtin.py
@@ -54,7 +54,8 @@ def stop_lift_params(x: Expr) -> Expr:
     """
-    An indicator that the consumers of input tensor should not be lifted to transform_params function
+    An indicator that the consumers of the input tensor should not be
+    lifted to the transform_params function

     Parameters
     ----------
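
With the full series applied, the marker is no longer a PrimFunc attribute but the relax.builtin.stop_lift_params operator. A minimal end-to-end sketch of the final design, mirroring the test_stop_lifting test above (all shapes and names are taken from that test):

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Before:
    @R.function
    def func1(
        x: R.Tensor((256, 256), "float32"),
        w1: R.Tensor((256, 256), "float32"),
    ) -> R.Tensor((256, 256), "float32"):
        # Only the first argument is a runtime input; the rest are weights.
        R.func_attr({"num_input": 1})
        with R.dataflow():
            w1_t = R.permute_dims(w1, [1, 0])
            # Consumers of this value must not be precomputed at transform time.
            w1_t1 = R.builtin.stop_lift_params(w1_t)
            w1_add = R.add(w1_t1, R.const(1, "float32"))
            y = R.matmul(x, w1_add)
            R.output(y)
        return y

# The pass splits func1 into func1(x, params) and func1_transform_params(params):
# the permute_dims is precomputed in the transform function, while the add and
# matmul downstream of stop_lift_params stay in func1, and the stop_lift_params
# call itself is erased by the rewriter.
after = relax.transform.LiftTransformParams()(Before)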