From 3952b726caa63d9bca8da6db33fd46d68cd028bf Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 13 Dec 2022 19:16:00 +0900
Subject: [PATCH 1/6] introduce LowerToPrimFunc to lower Relay func to TIR prim func

---
 src/relay/backend/task_extraction.cc   |  6 ++----
 src/relay/backend/te_compiler_cache.cc | 23 +++++++++++++++++++++++
 src/relay/backend/te_compiler_cache.h  |  4 ++++
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc
index e7e677938e1a..fc45311e085d 100644
--- a/src/relay/backend/task_extraction.cc
+++ b/src/relay/backend/task_extraction.cc
@@ -59,7 +59,6 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target,
   using meta_schedule::ExtractedTask;
   using meta_schedule::ModuleEqual;
   using meta_schedule::ModuleHash;
-  backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
   backend::BindParamsInModule(mod, params);
   // is_vm=true for backward compatibility
   Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
@@ -84,10 +83,9 @@
     if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
       return;
     }
-    auto [inputs_outputs, constants, fused_name] =
-        tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
-
-    if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, constants)) {
+    auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply);
+
+    if (f) {
       IRModule tir_mod = PrimFuncToIRModule(f.value());
       lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod));
     }
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 511f0a901d11..877976065ea2 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -1088,6 +1088,29 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
   return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
 }
 
+std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
+                                                                Target target,
+                                                                NameSupply constant_name_supply) {
+  auto [inputs_outputs, constants, fused_name] =
+      tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
+  auto tir_converter = backend::GetTIRConverter();
+  return std::make_pair(tir_converter(inputs_outputs, constants), fused_name);
+}
+
+tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
+  auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
+  if (f_opt) {
+    return f_opt.value();
+  }
+  LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false);
+  return PrimFunc();
+}
+
+TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
+    .set_body_typed([](Function relay_func, Target target) {
+      return LowerToPrimFunc(relay_func, target);
+    });
+
 TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
   auto tgt = tvm::Target("ext_dev");
   LowerToTECompute lower_te_compute(tgt, NameSupply(""));
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index fcbf10477fdf..685e894f62d7 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -224,6 +224,10 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
     const Function& source_func, Target target, NameSupply constant_name_supply,
     bool return_inputs = true);
 
+std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
+                                                                Target target,
+                                                                NameSupply constant_name_supply);
+
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.

From 567360e28da7af1a2f817dc232c699d3b6c7f415 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 13 Dec 2022 19:20:43 +0900
Subject: [PATCH 2/6] add doc

---
 src/relay/backend/te_compiler_cache.h | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 685e894f62d7..e0c99714d05e 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -215,7 +215,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
  * \brief Lowers Relay primitive Function to TE Compute
  * \param source_func The primitive function to be lowered.
  * \param target The target we want to create schedule for.
- * \param constant_name_supply A name supplier for constants.
+ * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.
  * \param return_inputs If true, prepend input tensors to the output array of tensors.
 * \return Tuple of the lowered TE compute, constant raw data, and fused function name.
@@ -224,6 +224,14 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
     const Function& source_func, Target target, NameSupply constant_name_supply,
     bool return_inputs = true);
 
+/*!
+ * \brief Lowers Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \param constant_name_supply A name supplier for constants
+ * across different invocations of this function.
+ * \return A pair of the created prim func and the name of the fused function.
+ */
 std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                 Target target,
                                                                 NameSupply constant_name_supply);

From b87ed2bab4f266ff576b82d51ba83d7f90439a44 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 13 Dec 2022 19:35:45 +0900
Subject: [PATCH 3/6] expose to python

---
 python/tvm/relay/backend/te_compiler.py | 23 +++++++++++++++++++++++
 src/relay/backend/te_compiler_cache.h   |  2 +-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 5594e36cb855..0dc48b695acf 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -412,3 +412,26 @@ def get():
         The TE Compiler.
     """
     return _backend._TECompilerGlobal()
+
+
+def lower_to_primfunc(relay_func, target):
+    """Lowers Relay Function to TIR PrimFunc.
+
+    Parameters
+    ----------
+    relay_func: relay.Function
+        The source primitive function, created by FuseOps.
+
+    target : Target
+        The target we want to create schedule for.
+
+    Returns
+    -------
+    prim_func : tir.PrimFunc
+        The created prim func.
+    """
+    f = tvm._ffi.get_global_func("relay.backend.LowerToPrimFunc")
+    assert f is not None, "relay.backend.LowerToPrimFunc does not exist."
+
+    with target:
+        return f(relay_func, target)
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index e0c99714d05e..6192dd21fb54 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -226,7 +226,7 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
 
 /*!
  * \brief Lowers Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
- * \param source_func The primitive function to be lowered.
+ * \param relay_func The primitive function to be lowered.
  * \param target The target we want to create schedule for.
  * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.

From 2d2f703ff0af9ca8df25ca4952ad1d58923a8dd5 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 13 Dec 2022 19:39:17 +0900
Subject: [PATCH 4/6] adding test

---
 python/tvm/relay/backend/te_compiler.py       |  4 +-
 src/relay/backend/te_compiler_cache.cc        |  1 +
 src/relay/backend/te_compiler_cache.h         | 10 ++--
 ..._plan_update_buffer_allocation_location.py | 56 +++++++++++++++++++
 4 files changed, 64 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 0dc48b695acf..6c3930505a5c 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -415,7 +415,7 @@ def get():
 
 
 def lower_to_primfunc(relay_func, target):
-    """Lowers Relay Function to TIR PrimFunc.
+    """Lower Relay Function to TIR PrimFunc.
 
     Parameters
     ----------
@@ -423,7 +423,7 @@ def lower_to_primfunc(relay_func, target):
         The source primitive function, created by FuseOps.
 
     target : Target
-        The target we want to create schedule for.
+        The target we want to create a schedule for.
 
     Returns
     -------
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 877976065ea2..2480594d5ece 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -1099,6 +1099,7 @@ std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function&
 
 tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
   auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
+  (void)_;  // to suppress -Werror=unused-variable warning
   if (f_opt) {
     return f_opt.value();
   }
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 6192dd21fb54..0e4a77c16354 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -212,9 +212,9 @@ class CCacheValue : public ObjectRef {
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
 
 /*!
- * \brief Lowers Relay primitive Function to TE Compute
+ * \brief Lower Relay primitive Function to TE Compute
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
+ * \param target The target we want to create a schedule for.
  * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.
  * \param return_inputs If true, prepend input tensors to the output array of tensors.
@@ -225,9 +225,9 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
     bool return_inputs = true);
 
 /*!
- * \brief Lowers Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
+ * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
  * \param relay_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
+ * \param target The target we want to create a schedule for.
 * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.
  * \return A pair of the created prim func and the name of the fused function.
@@ -239,7 +239,7 @@ std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function&
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
+ * \param target The target we want to create a schedule for.
  * \param global_var_supply A name supplier for global variables.
  * \param constant_name_supply A name supplier for constants.
  * \return Pair of schedule and cache.
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 92e3cbd66e2f..0a8a0dd59fbf 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -14,10 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+
 import tvm
 import tvm.testing
 from tvm import te
 from tvm.script import tir as T
+from tvm import relay, tir
+from tvm.relay.backend.te_compiler import lower_to_primfunc
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
 
 
 def _check(original, transformed):
@@ -360,5 +365,56 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
     _check(before, after)
 
 
+def test_allocate_const_after_tensorize():
+    i_size, o_size, h_size, w_size = 64, 64, 56, 56
+    k_height_size = k_width_size = 3
+    w_shape = (o_size, i_size, k_height_size, k_width_size)
+
+    data = relay.var("data", shape=(1, i_size, h_size, w_size), dtype="uint8")
+    weight = relay.var("weight", shape=w_shape, dtype="uint8")
+    conv2d = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=(k_height_size, k_width_size),
+        channels=o_size,
+        padding=(0, 0),
+        strides=(1, 1),
+        out_dtype="int32",
+    )
+    mod = tvm.IRModule.from_expr(conv2d)
+
+    executor = relay.backend.Executor("graph", {"link-params": True})
+    mod = mod.with_attr("executor", executor)
+
+    weight_np = np.random.uniform(1, 10, size=w_shape).astype("uint8")
+
+    target = tvm.target.Target("hexagon")
+
+    with tvm.transform.PassContext(opt_level=3):
+        opt_mod, _ = relay.optimize(mod, params={"weight": weight_np}, target=target)
+
+    conv2d_func = opt_mod["main"].body.args[0].op
+    prim_func = lower_to_primfunc(conv2d_func, target)
+
+    sch = tir.Schedule(prim_func)
+    block = sch.get_block("conv2d_NCHWc_int8")
+    loops = sch.get_loops(block)
+
+    sch.reorder(loops[8], loops[4], loops[-1])
+    sch.decompose_reduction(block, loops[1])
+    sch.tensorize(loops[4], VRMPY_u8u8i32_INTRIN)
+
+    seq = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.LowerInitBlock(),
+            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
+        ]
+    )
+
+    # The following error is emitted if AllocateConst nodes are not correctly handled:
+    # Check failed: (buffer_data_to_buffer_.count(source_var)) is false:
+    _ = seq(sch.mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 848ac096adf5c0f14d1935468565c4ce89a5e4af Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 13 Dec 2022 20:45:15 +0900
Subject: [PATCH 5/6] another minor doc update

---
 python/tvm/relay/backend/te_compiler.py | 2 +-
 src/relay/backend/te_compiler_cache.h   | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 6c3930505a5c..814e79329019 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -423,7 +423,7 @@ def lower_to_primfunc(relay_func, target):
         The source primitive function, created by FuseOps.
 
     target : Target
-        The target we want to create a schedule for.
+        The compilation target.
 
     Returns
     -------
diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h
index 0e4a77c16354..76939a923cdf 100644
--- a/src/relay/backend/te_compiler_cache.h
+++ b/src/relay/backend/te_compiler_cache.h
@@ -214,7 +214,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
 /*!
  * \brief Lower Relay primitive Function to TE Compute
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create a schedule for.
+ * \param target The compilation target.
  * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.
  * \param return_inputs If true, prepend input tensors to the output array of tensors.
@@ -227,7 +227,7 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
 /*!
  * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
  * \param relay_func The primitive function to be lowered.
- * \param target The target we want to create a schedule for.
+ * \param target The compilation target.
 * \param constant_name_supply A name supplier for constants
  * across different invocations of this function.
  * \return A pair of the created prim func and the name of the fused function.
@@ -239,7 +239,7 @@ std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function&
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create a schedule for.
+ * \param target The compilation target.
  * \param global_var_supply A name supplier for global variables.
 * \param constant_name_supply A name supplier for constants.
  * \return Pair of schedule and cache.

From feba9c3cfb1e3c9fc42452d890d4d96e6c8997f0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 14 Dec 2022 13:59:59 +0900
Subject: [PATCH 6/6] Verify that the input is a primitive function

---
 src/relay/backend/te_compiler_cache.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 2480594d5ece..d71cbcfc667d 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -1091,6 +1091,9 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
 std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                 Target target,
                                                                 NameSupply constant_name_supply) {
+  ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive))
+      << "The input must be a Relay primitive function.";
+
   auto [inputs_outputs, constants, fused_name] =
       tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
   auto tir_converter = backend::GetTIRConverter();
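
For reference, the lowering entry point added by this series can be driven from Python as sketched below. This is a condensed restatement, not part of the patches: it mirrors the flow of test_allocate_const_after_tensorize from patch 4 (uint8 conv2d workload, link-params executor, hexagon target), trimmed down to the lowering step, and the trailing print is illustrative only.

# Sketch only: follows the test added in patch 4; the way the fused function is
# pulled out of the optimized module (main's body.args[0].op) comes from that
# test and is not a general recipe for arbitrary models.
import numpy as np

import tvm
from tvm import relay
from tvm.relay.backend.te_compiler import lower_to_primfunc

# Build a small uint8 conv2d workload, as in the test.
data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="uint8")
conv2d = relay.nn.conv2d(
    data=data, weight=weight, kernel_size=(3, 3), channels=64, out_dtype="int32"
)
mod = tvm.IRModule.from_expr(conv2d)
mod = mod.with_attr("executor", relay.backend.Executor("graph", {"link-params": True}))

weight_np = np.random.uniform(1, 10, size=(64, 64, 3, 3)).astype("uint8")
target = tvm.target.Target("hexagon")

# Run the standard optimization pipeline (including FuseOps) so that primitive
# functions are formed, then pick out the fused conv2d as the test does.
with tvm.transform.PassContext(opt_level=3):
    opt_mod, _ = relay.optimize(mod, params={"weight": weight_np}, target=target)
conv2d_func = opt_mod["main"].body.args[0].op

# New API: lower a Relay primitive Function directly to a TIR PrimFunc,
# which can then be scheduled with tir.Schedule as shown in the test.
prim_func = lower_to_primfunc(conv2d_func, target)
print(prim_func)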