From 033f56d34424729fd2f91495cd7bfe6ed6b8268c Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 29 Sep 2020 15:12:11 -0700 Subject: [PATCH 01/21] [TIR] Support build TIR function. --- python/tvm/driver/build_module.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 7ad48e19a1db..ac19d953a2c6 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,17 +159,30 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] + is_tir_schedule = False + # Phase 0 if isinstance(sch, schedule.Schedule): mod = form_irmodule(sch, args, name, binds) + elif isinstance(sch, tvm.tir.PrimFunc): + func = sch.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.restricted_func"): + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({name: func}) + is_tir_schedule = True else: mod = sch pass_list = lower_phase0 # Phase 1 + pass_list += [tvm.tir.transform.InjectPrefetch()] + + if is_tir_schedule: + pass + # pass_list += [tvm.tir.transform.BufferFlatten()] + else: + pass_list += [tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers)] pass_list += [ - tvm.tir.transform.InjectPrefetch(), - tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), @@ -369,8 +382,8 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi ---- See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, schedule.Schedule): - if args is None: + if isinstance(inputs, (schedule.Schedule, tvm.tir.PrimFunc)): + if args is None and isinstance(inputs, schedule.Schedule): raise ValueError("args must be given for build from schedule") input_mod = lower(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): From 904fb01335e520d6e15c08dd283ace21461b9d3c Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 12 Oct 2020 17:25:55 -0700 Subject: [PATCH 02/21] [TIR] Support Return. --- include/tvm/tir/builtin.h | 4 ++ src/tir/op/builtin.cc | 4 ++ src/tir/transforms/make_packed_api.cc | 59 ++++++++++++++++++++++++- tests/python/unittest/test_tir_build.py | 16 +++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/python/unittest/test_tir_build.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index a150595ab551..bd1688ea425b 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -41,6 +41,10 @@ namespace tir { /*! \brief Collection of builtin intrinsics as ops */ namespace builtin { +/*! + * \brief Return value. + */ +TVM_DLL const Op& myreturn(); /*! * \brief Reinterpret the value using the target type. */ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 796b113a4054..77b4ac980e21 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -42,6 +42,10 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(myreturn) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_num_inputs(1); + TIR_DEFINE_BUILTIN_FUNC(likely) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation)) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7c4a8ef92724..63ae81274af9 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,6 +41,58 @@ namespace tvm { namespace tir { +class ReturnRewriter : public StmtMutator { + public: + explicit ReturnRewriter(Var ret_var, Var ret_tcode) + : ret_var_(ret_var), ret_tcode_(ret_tcode) {} + + Stmt VisitStmt_(const EvaluateNode* node) override { + Stmt ret = StmtMutator::VisitStmt_(node); + const EvaluateNode* eval = ret.as(); + CHECK(eval); + if (const CallNode* call = eval->value.as()) { + if (call->op.same_as(builtin::myreturn())) { + CHECK_EQ(call->args.size(), 1); + ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); + } + } + return ret; + } + private: + std::pair ConvertForFFI(PrimExpr val) { + DataType dtype = val.dtype(); + if (dtype.is_int() || dtype.is_uint()) { + return {kTVMArgInt, Cast(DataType::Int(64), val)}; + } else if (dtype.is_float()) { + return {kTVMArgFloat, Cast(DataType::Float(64), val)}; + } else if (dtype.is_void()) { + return {kTVMNullptr, val}; + } else { + LOG(FATAL) << "data type " << dtype << " not supported yet"; + } + return {kTVMNullptr, val}; + } + + // convert val's data type to FFI data type, return type code + Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { + auto p = ConvertForFFI(val); + int tcode = p.first; + val = p.second; + Stmt store_val = Store(ret_var_, val, 0, const_true()); + Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true()); + return SeqStmt({store_val, store_tcode}); + } + + Var ret_var_; + Var ret_tcode_; +}; + +Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { + ReturnRewriter rewriter(ret_var, ret_tcode); + return rewriter(body); +} + + inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } @@ -182,8 +234,9 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, - StringImm(name_hint + "_compute_"), func_ptr->body); + Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode); + body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, + StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImm("default"); @@ -222,6 +275,7 @@ namespace transform { Pass MakePackedAPI(int num_unpacked_args) { auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { + LOG(INFO) << "Before Make Packed API:\n" << m; IRModuleNode* mptr = m.CopyOnWrite(); std::vector > updates; @@ -239,6 +293,7 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& pair : updates) { mptr->AddUnchecked(pair.first, pair.second); } + LOG(INFO) << "After Make Packed API:\n" << m; return m; }; diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py new file mode 100644 index 000000000000..bde5f2d6fa7a --- /dev/null +++ b/tests/python/unittest/test_tir_build.py @@ -0,0 +1,16 @@ +import tvm +from tvm import tir + +def add(): + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") + c = a + b + c = tir.call_intrin("float32", "tir.myreturn", c) + c = tir.Evaluate(c) + func = tir.PrimFunc([a, b], c) + mod = tvm.IRModule({'add': func}) + func = tvm.build(mod['add']) + out = func(1.0, 2.0) + print(out) + +add() From 30a691b6212149faf8d03bb5eeb1ff7c28429297 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 10 Dec 2020 15:27:39 -0800 Subject: [PATCH 03/21] [TIR] Update. --- python/tvm/driver/build_module.py | 5 +---- src/tir/transforms/make_packed_api.cc | 2 -- tests/python/unittest/test_tir_build.py | 2 +- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index ac19d953a2c6..baa357b07640 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -177,10 +177,7 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): # Phase 1 pass_list += [tvm.tir.transform.InjectPrefetch()] - if is_tir_schedule: - pass - # pass_list += [tvm.tir.transform.BufferFlatten()] - else: + if not is_tir_schedule: pass_list += [tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers)] pass_list += [ tvm.tir.transform.BF16Legalize(), diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 63ae81274af9..eebb14e3a321 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -275,7 +275,6 @@ namespace transform { Pass MakePackedAPI(int num_unpacked_args) { auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { - LOG(INFO) << "Before Make Packed API:\n" << m; IRModuleNode* mptr = m.CopyOnWrite(); std::vector > updates; @@ -293,7 +292,6 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& pair : updates) { mptr->AddUnchecked(pair.first, pair.second); } - LOG(INFO) << "After Make Packed API:\n" << m; return m; }; diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index bde5f2d6fa7a..2ba230e5debb 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -11,6 +11,6 @@ def add(): mod = tvm.IRModule({'add': func}) func = tvm.build(mod['add']) out = func(1.0, 2.0) - print(out) + assert out == 3.0 add() From 221a2f43bb2ec86a199e444a3f05c7640175df26 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 14 Dec 2020 13:32:53 -0800 Subject: [PATCH 04/21] [TIR] Rename to ret. --- include/tvm/tir/builtin.h | 2 +- src/tir/op/builtin.cc | 2 +- src/tir/transforms/make_packed_api.cc | 2 +- tests/python/unittest/test_tir_build.py | 7 ++++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index bd1688ea425b..6a40d86b8984 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -44,7 +44,7 @@ namespace builtin { /*! * \brief Return value. */ -TVM_DLL const Op& myreturn(); +TVM_DLL const Op& ret(); /*! * \brief Reinterpret the value using the target type. */ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 77b4ac980e21..137c0347a170 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -42,7 +42,7 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_num_inputs(1); -TIR_DEFINE_BUILTIN_FUNC(myreturn) +TIR_DEFINE_BUILTIN_FUNC(ret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_num_inputs(1); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index eebb14e3a321..308befd69306 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -51,7 +51,7 @@ class ReturnRewriter : public StmtMutator { const EvaluateNode* eval = ret.as(); CHECK(eval); if (const CallNode* call = eval->value.as()) { - if (call->op.same_as(builtin::myreturn())) { + if (call->op.same_as(builtin::ret())) { CHECK_EQ(call->args.size(), 1); ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); } diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index 2ba230e5debb..0e7b246b6e3b 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -1,11 +1,11 @@ import tvm from tvm import tir -def add(): +def test_scalar_add(): a = tir.Var("a", "float32") b = tir.Var("b", "float32") c = a + b - c = tir.call_intrin("float32", "tir.myreturn", c) + c = tir.call_intrin("float32", "tir.ret", c) c = tir.Evaluate(c) func = tir.PrimFunc([a, b], c) mod = tvm.IRModule({'add': func}) @@ -13,4 +13,5 @@ def add(): out = func(1.0, 2.0) assert out == 3.0 -add() +if __name__ == "__main__": + test_scalar_add() From 424093be2163920cbefabba8dd5863c5722864ae Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 14 Dec 2020 14:49:52 -0800 Subject: [PATCH 05/21] [TIR] Update build. --- python/tvm/driver/build_module.py | 18 +++++------------- tests/python/unittest/test_tir_build.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index baa357b07640..f226dda6b286 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,27 +159,18 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - is_tir_schedule = False # Phase 0 if isinstance(sch, schedule.Schedule): mod = form_irmodule(sch, args, name, binds) - elif isinstance(sch, tvm.tir.PrimFunc): - func = sch.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.restricted_func"): - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: func}) - is_tir_schedule = True else: mod = sch pass_list = lower_phase0 # Phase 1 - pass_list += [tvm.tir.transform.InjectPrefetch()] - - if not is_tir_schedule: - pass_list += [tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers)] pass_list += [ + tvm.tir.transform.InjectPrefetch(), + tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), tvm.tir.transform.BF16Legalize(), tvm.tir.transform.NarrowDataType(32), tvm.tir.transform.Simplify(), @@ -215,6 +206,7 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): optimize = tvm.transform.Sequential(pass_list) mod = optimize(mod) + print(mod) return mod @@ -379,8 +371,8 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi ---- See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, (schedule.Schedule, tvm.tir.PrimFunc)): - if args is None and isinstance(inputs, schedule.Schedule): + if isinstance(inputs, schedule.Schedule): + if args is None: raise ValueError("args must be given for build from schedule") input_mod = lower(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index 0e7b246b6e3b..339266173788 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -1,15 +1,20 @@ import tvm from tvm import tir +from tvm.ir.transform import PassContext def test_scalar_add(): - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") c = a + b - c = tir.call_intrin("float32", "tir.ret", c) + c = tir.call_intrin("float32", "tir.ret", c) c = tir.Evaluate(c) func = tir.PrimFunc([a, b], c) - mod = tvm.IRModule({'add': func}) - func = tvm.build(mod['add']) + func = func.with_attr("global_symbol", "main") + pass_ctx = PassContext.current() + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({'main': func}) + func = tvm.build(mod) out = func(1.0, 2.0) assert out == 3.0 From ecff372269719c5a48ac11512dab2a54fb4f1a1c Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 14 Dec 2020 14:51:18 -0800 Subject: [PATCH 06/21] [TIR] Update. --- python/tvm/driver/build_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index f226dda6b286..7ad48e19a1db 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -159,7 +159,6 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2] - # Phase 0 if isinstance(sch, schedule.Schedule): mod = form_irmodule(sch, args, name, binds) @@ -206,7 +205,6 @@ def lower(sch, args, name="main", binds=None, simple_mode=False): optimize = tvm.transform.Sequential(pass_list) mod = optimize(mod) - print(mod) return mod From 71942344c9f95d5c959cbb2b048611644d85ae5f Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 14 Dec 2020 14:52:22 -0800 Subject: [PATCH 07/21] [TIR] ASF header. --- tests/python/unittest/test_tir_build.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index 339266173788..4a5b460f17a8 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. import tvm from tvm import tir from tvm.ir.transform import PassContext From 86cb0f83bbe2cc25135b7f01ff45a3d2b7a44ed5 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 14 Dec 2020 17:10:59 -0800 Subject: [PATCH 08/21] [TIR] Fix lint. --- src/tir/transforms/make_packed_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 308befd69306..5da83f738eb4 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -58,6 +58,7 @@ class ReturnRewriter : public StmtMutator { } return ret; } + private: std::pair ConvertForFFI(PrimExpr val) { DataType dtype = val.dtype(); @@ -72,7 +73,6 @@ class ReturnRewriter : public StmtMutator { } return {kTVMNullptr, val}; } - // convert val's data type to FFI data type, return type code Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { auto p = ConvertForFFI(val); From fe49d69a00a32e962aec1deadc536b430858dda3 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 15 Dec 2020 09:20:58 -0800 Subject: [PATCH 09/21] [TIR] Fix lint. --- src/tir/transforms/make_packed_api.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 5da83f738eb4..eed653208c99 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -43,8 +43,7 @@ namespace tir { class ReturnRewriter : public StmtMutator { public: - explicit ReturnRewriter(Var ret_var, Var ret_tcode) - : ret_var_(ret_var), ret_tcode_(ret_tcode) {} + explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} Stmt VisitStmt_(const EvaluateNode* node) override { Stmt ret = StmtMutator::VisitStmt_(node); @@ -92,7 +91,6 @@ Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { return rewriter(body); } - inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } From 2d4d1787a33064dabf331bb40941d902186c2933 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 15 Dec 2020 09:53:57 -0800 Subject: [PATCH 10/21] [TIR] Fix lint. --- tests/python/unittest/test_tir_build.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index 4a5b460f17a8..fb6c45cd0ea5 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -18,6 +18,7 @@ from tvm import tir from tvm.ir.transform import PassContext + def test_scalar_add(): a = tir.Var("a", "float32") b = tir.Var("b", "float32") @@ -29,10 +30,11 @@ def test_scalar_add(): pass_ctx = PassContext.current() if pass_ctx.config.get("tir.noalias", True): func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({'main': func}) + mod = tvm.IRModule({"main": func}) func = tvm.build(mod) out = func(1.0, 2.0) assert out == 3.0 + if __name__ == "__main__": test_scalar_add() From a8d42ac62b0961184a8785ece659aeb717e5b79b Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 8 Jan 2021 16:16:25 -0800 Subject: [PATCH 11/21] [TIR] Update --- include/tvm/tir/op_attr_types.h | 6 +++++- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/op.py | 16 ++++++++++++++++ src/tir/op/builtin.cc | 2 +- src/tir/transforms/make_packed_api.cc | 4 ++-- tests/python/unittest/test_tir_build.py | 2 +- 6 files changed, 26 insertions(+), 6 deletions(-) diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index ec7fc172cde8..3dcc4b943a79 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -74,7 +74,11 @@ enum class CallEffectKind : int { /*! * \brief Embed opaque information in the Expr, cannot be codegen. */ - kEmbedInfo = 5 + kEmbedInfo = 5, + /*! + * \brief Function that changes control flow + */ + kControlJump = 6, }; /*! \brief Use integer to record the kind. */ diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1aac55fa9920..901c89ed9106 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -35,7 +35,7 @@ from .function import PrimFunc from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace +from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index ca61be4fcd83..8ea29e5e9f14 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): ) +def ret(val): + """Create a tir return expression + + Parameters + ---------- + val : Expr + The returned tir expression, whose data type is int, float or void pointer. + + Returns + ------- + ret : PrimExpr + The return expression + """ + return call_intrin(val.dtype, "tir.ret", val) + + def any(*args, span=None): """Create a new experssion of the union of all conditions in the arguments diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 137c0347a170..1117571c8b75 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -43,7 +43,7 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_num_inputs(1); TIR_DEFINE_BUILTIN_FUNC(ret) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(1); TIR_DEFINE_BUILTIN_FUNC(likely) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index eed653208c99..0676fd37509b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -48,10 +48,10 @@ class ReturnRewriter : public StmtMutator { Stmt VisitStmt_(const EvaluateNode* node) override { Stmt ret = StmtMutator::VisitStmt_(node); const EvaluateNode* eval = ret.as(); - CHECK(eval); + ICHECK(eval); if (const CallNode* call = eval->value.as()) { if (call->op.same_as(builtin::ret())) { - CHECK_EQ(call->args.size(), 1); + ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); } } diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_build.py index fb6c45cd0ea5..5a03d0ed070a 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_build.py @@ -23,7 +23,7 @@ def test_scalar_add(): a = tir.Var("a", "float32") b = tir.Var("b", "float32") c = a + b - c = tir.call_intrin("float32", "tir.ret", c) + c = tir.ret(c) c = tir.Evaluate(c) func = tir.PrimFunc([a, b], c) func = func.with_attr("global_symbol", "main") From 7facb69493b0371c80a59fc1c62af50b01ebc191 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 8 Jan 2021 18:01:02 -0800 Subject: [PATCH 12/21] [TIR] Fix lint. --- python/tvm/tir/op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8ea29e5e9f14..973fb164401e 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -229,7 +229,7 @@ def ret(val): val : Expr The returned tir expression, whose data type is int, float or void pointer. - Returns + Returns ------- ret : PrimExpr The return expression From f2578cd8b728f724cce8875f7288144b6117e6ed Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 14 Jan 2021 18:50:37 -0800 Subject: [PATCH 13/21] [TIR] Handle control flow jump. --- include/tvm/tir/op.h | 8 +++++ src/target/llvm/codegen_llvm.cc | 6 ++++ src/tir/op/op.cc | 4 +++ src/tir/transforms/make_packed_api.cc | 6 ++-- .../{test_tir_build.py => test_tir_base.py} | 31 +++++++++++++++---- 5 files changed, 47 insertions(+), 8 deletions(-) rename tests/python/unittest/{test_tir_build.py => test_tir_base.py} (74%) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 4a907fca951d..d2a85ed228ae 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -70,6 +70,14 @@ TVM_DLL Type GetType(const PrimExpr& expr); */ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); +/*! + * \brief Return the value. + * + * \param value The returned value. + * \return The return expression. + */ +TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); + /*! * Query the maximum possible value of dtype. * \param dtype The data type. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 70f094a186e7..fe5c8885af7b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -927,6 +927,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(then_value, then_value_block); value->addIncoming(else_value, else_value_block); return value; + } else if (op->op.same_as(builtin::ret())) { + ICHECK_EQ(Downcast(op->args[0])->value, 0); + builder_->CreateRet(ConstInt32(0)); + llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); + builder_->SetInsertPoint(ret_dummy); + return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index b576fe4faee8..9fcb07149d19 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -145,6 +145,10 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } +PrimExpr ret(PrimExpr value, Span span) { + return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); +} + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 0676fd37509b..e3ec353dbcf6 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -60,6 +60,7 @@ class ReturnRewriter : public StmtMutator { private: std::pair ConvertForFFI(PrimExpr val) { + // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); if (dtype.is_int() || dtype.is_uint()) { return {kTVMArgInt, Cast(DataType::Int(64), val)}; @@ -72,14 +73,15 @@ class ReturnRewriter : public StmtMutator { } return {kTVMNullptr, val}; } - // convert val's data type to FFI data type, return type code + Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) { auto p = ConvertForFFI(val); int tcode = p.first; val = p.second; Stmt store_val = Store(ret_var_, val, 0, const_true()); Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true()); - return SeqStmt({store_val, store_tcode}); + Stmt ret_zero = Evaluate(tvm::ret(0)); + return SeqStmt({store_val, store_tcode, ret_zero}); } Var ret_var_; diff --git a/tests/python/unittest/test_tir_build.py b/tests/python/unittest/test_tir_base.py similarity index 74% rename from tests/python/unittest/test_tir_build.py rename to tests/python/unittest/test_tir_base.py index 5a03d0ed070a..8b9467145fa2 100644 --- a/tests/python/unittest/test_tir_build.py +++ b/tests/python/unittest/test_tir_base.py @@ -18,6 +18,15 @@ from tvm import tir from tvm.ir.transform import PassContext +def build_tir_func(func): + func = func.with_attr("global_symbol", "main") + pass_ctx = PassContext.current() + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + mod = tvm.IRModule({"main": func}) + func = tvm.build(mod) + return func + def test_scalar_add(): a = tir.Var("a", "float32") @@ -26,15 +35,25 @@ def test_scalar_add(): c = tir.ret(c) c = tir.Evaluate(c) func = tir.PrimFunc([a, b], c) - func = func.with_attr("global_symbol", "main") - pass_ctx = PassContext.current() - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) - mod = tvm.IRModule({"main": func}) - func = tvm.build(mod) + func = build_tir_func(func) out = func(1.0, 2.0) assert out == 3.0 +def test_control_flow_jump(): + ib = tvm.tir.ir_builder.create() + a = tir.Var("a", "float32") + b = tir.Var("b", "float32") + with ib.if_scope(True): + ib.emit(tir.Evaluate(tir.ret(a))) + ib.emit(tir.Evaluate(tir.ret(b))) + stmt = ib.get() + func = tir.PrimFunc([a, b], stmt) + func = build_tir_func(func) + out = func(1.0, 2.0) + assert out == 1.0 + + if __name__ == "__main__": test_scalar_add() + test_control_flow_jump() From b04c3a06d0222920a4ccf61b13843eaad99c050e Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 14 Jan 2021 21:00:12 -0800 Subject: [PATCH 14/21] [TIR] Fix lint. --- include/tvm/tir/op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index d2a85ed228ae..a5eb2d36ce7a 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -72,7 +72,7 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); /*! * \brief Return the value. - * + * * \param value The returned value. * \return The return expression. */ From 19be07dfc066ca66c34900c619fc402ef7190ff3 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 14 Jan 2021 21:02:23 -0800 Subject: [PATCH 15/21] [TIR] Fix lint. --- tests/python/unittest/test_tir_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 8b9467145fa2..6e081a179059 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -18,6 +18,7 @@ from tvm import tir from tvm.ir.transform import PassContext + def build_tir_func(func): func = func.with_attr("global_symbol", "main") pass_ctx = PassContext.current() @@ -25,7 +26,7 @@ def build_tir_func(func): func = func.with_attr("tir.noalias", True) mod = tvm.IRModule({"main": func}) func = tvm.build(mod) - return func + return func def test_scalar_add(): From 5ab6d8320fdc63a6021e2875ba0bcf42ceac3a76 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 14 Jan 2021 21:05:58 -0800 Subject: [PATCH 16/21] [TIR] Fix lint. --- python/tvm/tir/op.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 973fb164401e..182264f0db92 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -257,10 +257,10 @@ def any(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _ffi_api._OpOr(args[0], args[1], span) + val = _ffi_api._OpOr(args[0], args[1], span) for i in range(2, len(args)): - ret = _ffi_api._OpOr(ret, args[i], span) - return ret + val = _ffi_api._OpOr(val, args[i], span) + return val def all(*args, span=None): @@ -284,10 +284,10 @@ def all(*args, span=None): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _ffi_api._OpAnd(args[0], args[1], span) + val = _ffi_api._OpAnd(args[0], args[1], span) for i in range(2, len(args)): - ret = _ffi_api._OpAnd(ret, args[i], span) - return ret + val = _ffi_api._OpAnd(val, args[i], span) + return val @tvm._ffi.register_func("tvm.default_trace_action") From 6525e9fa538fb9f6e97b0b528c587c1e9e860c02 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 14 Jan 2021 21:09:03 -0800 Subject: [PATCH 17/21] [TIR] Fix lint. --- include/tvm/tir/op.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index a5eb2d36ce7a..b5a62c907ed6 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -74,6 +74,7 @@ TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); * \brief Return the value. * * \param value The returned value. + * \param span The location of this operation in the source. * \return The return expression. */ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); From 6c87c3d138f7535849a9861b1229f4ec47bc9740 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 15 Jan 2021 15:08:02 -0800 Subject: [PATCH 18/21] [TIR] Error message. --- src/target/llvm/codegen_llvm.cc | 6 +++++- src/tir/transforms/make_packed_api.cc | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index fe5c8885af7b..241b531a2a24 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -928,8 +928,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->op.same_as(builtin::ret())) { - ICHECK_EQ(Downcast(op->args[0])->value, 0); + auto const* val = op->args[0].as(); + ICHECK(val) << "the tir.ret should be transformed to return zero before the llvm code generation."; + ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); + // LLVM allows exactly one terminator in a single basic block + // append a new dummy basic block to avoid error. llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index e3ec353dbcf6..8309b7d5d729 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -45,12 +45,20 @@ class ReturnRewriter : public StmtMutator { public: explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} + Stmt VisitStmt_(const ForNode* node) override { + if (node->for_type == ForType::Parallel) in_parallel_ += 1; + Stmt ret = StmtMutator::VisitStmt_(node); + if (node->for_type == ForType::Parallel) in_parallel_ -= 1; + return ret; + } + Stmt VisitStmt_(const EvaluateNode* node) override { Stmt ret = StmtMutator::VisitStmt_(node); const EvaluateNode* eval = ret.as(); ICHECK(eval); if (const CallNode* call = eval->value.as()) { if (call->op.same_as(builtin::ret())) { + ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; ret = WriteToOut(call->args[0], ret_var_, ret_tcode_); } @@ -86,6 +94,7 @@ class ReturnRewriter : public StmtMutator { Var ret_var_; Var ret_tcode_; + int in_parallel_{0}; }; Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) { From 9737b62676dcbdd96c3657bea93eda50fc624dff Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 15 Jan 2021 15:13:03 -0800 Subject: [PATCH 19/21] [TIR] Fix lint. --- src/target/llvm/codegen_llvm.cc | 6 ++++-- src/tir/transforms/make_packed_api.cc | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 241b531a2a24..02af8171d75e 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -929,8 +929,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return value; } else if (op->op.same_as(builtin::ret())) { auto const* val = op->args[0].as(); - ICHECK(val) << "the tir.ret should be transformed to return zero before the llvm code generation."; - ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to return zero before the llvm code generation."; + ICHECK(val) << "the tir.ret should be transformed to return zero " + << "before the llvm code generation."; + ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " + << "return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 8309b7d5d729..adbe78a6d627 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -50,7 +50,7 @@ class ReturnRewriter : public StmtMutator { Stmt ret = StmtMutator::VisitStmt_(node); if (node->for_type == ForType::Parallel) in_parallel_ -= 1; return ret; - } + } Stmt VisitStmt_(const EvaluateNode* node) override { Stmt ret = StmtMutator::VisitStmt_(node); From 06ac4a230cc20e23ac95dfe58e938858cb61de33 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 15 Jan 2021 15:17:32 -0800 Subject: [PATCH 20/21] [TIR] Fix lint. --- src/target/llvm/codegen_llvm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 02af8171d75e..32e0a29e1b65 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -930,9 +930,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::ret())) { auto const* val = op->args[0].as(); ICHECK(val) << "the tir.ret should be transformed to return zero " - << "before the llvm code generation."; + << "before the llvm code generation."; ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " - << "return zero before the llvm code generation."; + << "return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. From 4dd827b4419a48f422cabe9e7d9190770d612a04 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 15 Jan 2021 15:28:54 -0800 Subject: [PATCH 21/21] [TIR] Fix lint. --- src/target/llvm/codegen_llvm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 32e0a29e1b65..34f3897cce88 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -932,7 +932,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { ICHECK(val) << "the tir.ret should be transformed to return zero " << "before the llvm code generation."; ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " - << "return zero before the llvm code generation."; + << "return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error.