From 025447c8c1bad813ef55beb59d70b0523f0a2e17 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 27 Apr 2021 16:36:15 -0700 Subject: [PATCH 01/16] Add legalization part --- src/target/intrin_rule.cc | 39 +++++++++++++----------- src/tir/transforms/lower_intrin.cc | 48 ++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index bfc3fe6fcc8c..0e19e7bbd7dd 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -112,41 +112,45 @@ TVM_REGISTER_OP("tir.ceil") TVM_REGISTER_OP("tir.round") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.rsqrt") +TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); + +TVM_REGISTER_OP("tir.isfinite") .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - return one / sqrt(call->args[0]); + return isfinite(call->args[0]); }); -TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); - -TVM_REGISTER_OP("tir.sigmoid") +TVM_REGISTER_OP("tir.isinf") .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - return one / (one + exp(-call->args[0])); + return isinf(call->args[0]); }); -TVM_REGISTER_OP("tir.isfinite") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { +} // namespace intrin +namespace legalize { +using namespace tir; + +TVM_REGISTER_OP("tir.rsqrt") + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - return isfinite(call->args[0]); + auto one = make_const(call->args[0].dtype(), 1); + return one / sqrt(call->args[0]); }); -TVM_REGISTER_OP("tir.isinf") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { +TVM_REGISTER_OP("tir.sigmoid") + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - return isinf(call->args[0]); + auto one = make_const(call->args[0].dtype(), 1); + return one / (one + exp(-call->args[0])); }); TVM_REGISTER_OP("tir.q_multiply_shift") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); @@ -221,7 +225,6 @@ TVM_REGISTER_OP("tir.q_multiply_shift") return cast(lp_dtype, x); } }); - -} // namespace intrin +} // namespace legalize } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4101891db699..6fe228f96223 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -40,26 +40,29 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") - : IRMutatorWithAnalyzer(analyzer) { - std::vector patterns_; - patterns_.push_back(target + ".FLowerIntrinsic"); - + template + inline std::vector> retrieve_attr_maps( + const std::string& type_name, const std::string& target, const std::string& mtriple) { + std::vector patterns; + patterns.push_back(target + "." + type_name); bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); if (is_llvm_aarch64) { - patterns_.push_back(target + ".aarch64.FLowerIntrinsic"); - } - - patterns_.push_back("default.FLowerIntrinsic"); - - fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma"); - if (target == "stackvm") { - support_bitwise_op_ = false; + patterns.push_back(target + ".aarch64." + type_name); } + patterns.push_back("default." + type_name); - for (const std::string& pattern : patterns_) + std::vector> attr_maps; + for (const std::string& pattern : patterns) if (Op::HasAttrMap(pattern)) - lower_intrin_maps_.push_back(Op::GetAttrMap(pattern)); + attr_maps.push_back(Op::GetAttrMap(pattern)); + return attr_maps; + } + + IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") + : IRMutatorWithAnalyzer(analyzer) { + std::vector patterns_; + lower_intrin_maps_ = retrieve_attr_maps("FLowerIntrinsic", target, mtriple); + legalize_maps_ = retrieve_attr_maps("FLegalize", target, mtriple); } PrimExpr VisitExpr_(const CallNode* op) final { @@ -78,6 +81,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } + for (const auto& f_legalize_map : legalize_maps_) { + FLowerIntrinsic f = f_legalize_map.get(GetRef(ptr_op), nullptr); + if (f != nullptr) { + PrimExpr e = GetRef(op); + PrimExpr r = f(e); + ICHECK(r.defined()) << "legalize rule must always return valid Expr"; + if (!r.same_as(e)) { + r = this->VisitExpr(r); + if (r.defined()) { + return r; + } + } + } + } } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -282,6 +299,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // patterns std::vector> lower_intrin_maps_; + std::vector> legalize_maps_; const PackedFunc* fma_{nullptr}; bool support_bitwise_op_{true}; }; From e2001f081be8e7c8378d3225b96d817d052ef87a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 27 Apr 2021 17:32:26 -0700 Subject: [PATCH 02/16] Move legalization functions --- src/target/intrin_rule.cc | 28 ++++------- src/target/llvm/intrin_rule_llvm.cc | 70 ++++++++++++++------------- src/target/spirv/intrin_rule_spirv.cc | 53 ++++++++++---------- 3 files changed, 75 insertions(+), 76 deletions(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 0e19e7bbd7dd..c519cdc2bcd5 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -115,38 +115,30 @@ TVM_REGISTER_OP("tir.round") TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.isfinite") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - const CallNode* call = e.as(); - ICHECK(call != nullptr); - return isfinite(call->args[0]); - }); +} // namespace intrin -TVM_REGISTER_OP("tir.isinf") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { +namespace legalize { +using namespace tir; + +TVM_REGISTER_OP("tir.sigmoid") + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); return isinf(call->args[0]); }); -} // namespace intrin -namespace legalize { -using namespace tir; - -TVM_REGISTER_OP("tir.rsqrt") +TVM_REGISTER_OP("tir.isfinite") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - return one / sqrt(call->args[0]); + return isfinite(call->args[0]); }); -TVM_REGISTER_OP("tir.sigmoid") +TVM_REGISTER_OP("tir.isinf") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - return one / (one + exp(-call->args[0])); + return isinf(call->args[0]); }); TVM_REGISTER_OP("tir.q_multiply_shift") diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 2d30c2030685..6ac7a89cda66 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -30,6 +30,7 @@ namespace tvm { namespace codegen { namespace llvm { +namespace intrin { using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.prefetch") @@ -43,20 +44,6 @@ TVM_REGISTER_OP("tir.exp2") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -// TODO(tvm-team): migrate the legalization transformations as a separate -// set of rules in TIR that can be shared across backends. -TVM_REGISTER_OP("tir.exp10") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = exp(x * ln10); - return ret; - }); - TVM_REGISTER_OP("tir.fma").set_attr( "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); @@ -99,8 +86,37 @@ TVM_REGISTER_OP("tir.nearbyint") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); +TVM_REGISTER_OP("tir.pow").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); + +TVM_REGISTER_OP("tir.popcount") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); + +TVM_REGISTER_OP("tir.cos").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); + +TVM_REGISTER_OP("tir.sin").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +} // namespace intrin + +namespace legalize { +using tir::FLegalize; + +TVM_REGISTER_OP("tir.exp10") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + using tir::make_zero; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = exp(x * ln10); + return ret; + }); + TVM_REGISTER_OP("tir.tanh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -118,14 +134,7 @@ TVM_REGISTER_OP("tir.tanh") return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); -TVM_REGISTER_OP("tir.pow").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); - -TVM_REGISTER_OP("tir.popcount") - .set_attr("llvm.FLowerIntrinsic", - DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); - -TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLowerIntrinsic", +TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); @@ -135,11 +144,8 @@ TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLowerIntrinsic", return tan_x; }); -TVM_REGISTER_OP("tir.cos").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); - TVM_REGISTER_OP("tir.cosh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -153,11 +159,8 @@ TVM_REGISTER_OP("tir.cosh") return ret; }); -TVM_REGISTER_OP("tir.sin").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); - TVM_REGISTER_OP("tir.sinh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -171,8 +174,8 @@ TVM_REGISTER_OP("tir.sinh") return ret; }); -TVM_REGISTER_OP("tir.clz").set_attr( - "llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { +TVM_REGISTER_OP("tir.clz").set_attr( + "llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); ICHECK_EQ(call->args.size(), 1); @@ -186,6 +189,7 @@ TVM_REGISTER_OP("tir.clz").set_attr( return cast(call->dtype, clz); }); +} // namespace legalize } // namespace llvm } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index fa38f8fb0107..50f7eaa13006 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -30,8 +30,6 @@ namespace tvm { namespace codegen { namespace spirv { -using tir::FLowerIntrinsic; - // num_signature means number of arguments used to query signature template PrimExpr CallGLSLIntrin(PrimExpr e, const Array& args) { @@ -59,6 +57,8 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { return CallGLSLIntrin(e); } +namespace intrin { +using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.floor") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -98,29 +98,6 @@ TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.clz").set_attr( - "vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - PrimExpr arg = call->args[0]; - PrimExpr msb; - if (arg.dtype().bits() == 64) { - // SPIR-V FindUMsb intrinsic only supports 32 bit input - auto int32 = DataType::Int(32); - PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); - PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); - PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); - PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); - msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); - } else if (arg.dtype().bits() == 32) { - msb = CallGLSLIntrin(e); - } else { - LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; - } - return PrimExpr(arg.dtype().bits() - 1) - msb; - }); - // WebGPU rules. TVM_REGISTER_OP("tir.floor") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -151,7 +128,33 @@ TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); +} // namespace intrin +namespace legalize { + using tir::FLegalize; + TVM_REGISTER_OP("tir.clz").set_attr( + "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + PrimExpr arg = call->args[0]; + PrimExpr msb; + if (arg.dtype().bits() == 64) { + // SPIR-V FindUMsb intrinsic only supports 32 bit input + auto int32 = DataType::Int(32); + PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); + PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); + PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); + PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); + msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); + } else if (arg.dtype().bits() == 32) { + msb = CallGLSLIntrin(e); + } else { + LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; + } + return PrimExpr(arg.dtype().bits() - 1) - msb; + }); +} // namespace legalize } // namespace spirv } // namespace codegen } // namespace tvm From 01f651e3d30ccb890649f7e4d254e91359f5bb99 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 27 Apr 2021 17:41:37 -0700 Subject: [PATCH 03/16] Fix clang format --- src/target/llvm/intrin_rule_llvm.cc | 43 +++++++++++++-------------- src/target/spirv/intrin_rule_spirv.cc | 4 +-- src/tir/transforms/lower_intrin.cc | 8 ++--- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 6ac7a89cda66..adbd1056d962 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -134,15 +134,13 @@ TVM_REGISTER_OP("tir.tanh") return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); -TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", - [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = - e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr tan_x = sin(x) / cos(x); - return tan_x; - }); +TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr tan_x = sin(x) / cos(x); + return tan_x; +}); TVM_REGISTER_OP("tir.cosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { @@ -174,20 +172,19 @@ TVM_REGISTER_OP("tir.sinh") return ret; }); -TVM_REGISTER_OP("tir.clz").set_attr( - "llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - Array cargs; - cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); - cargs.push_back(IntImm(DataType::UInt(32), 2)); - cargs.push_back(call->args[0]); - cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef - // LLVM requires that the return type must match the first argument type - auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); - return cast(call->dtype, clz); - }); +TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + Array cargs; + cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); + cargs.push_back(IntImm(DataType::UInt(32), 2)); + cargs.push_back(call->args[0]); + cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef + // LLVM requires that the return type must match the first argument type + auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); + return cast(call->dtype, clz); +}); } // namespace legalize } // namespace llvm diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 50f7eaa13006..eca7c4ce1700 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -131,8 +131,8 @@ TVM_REGISTER_OP("tir.tanh") } // namespace intrin namespace legalize { - using tir::FLegalize; - TVM_REGISTER_OP("tir.clz").set_attr( +using tir::FLegalize; +TVM_REGISTER_OP("tir.clz").set_attr( "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); ICHECK(call != nullptr); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6fe228f96223..31bf10e6e82e 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -41,8 +41,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; template - inline std::vector> retrieve_attr_maps( - const std::string& type_name, const std::string& target, const std::string& mtriple) { + inline std::vector> retrieve_attr_maps(const std::string& type_name, + const std::string& target, + const std::string& mtriple) { std::vector patterns; patterns.push_back(target + "." + type_name); bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); @@ -53,8 +54,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { std::vector> attr_maps; for (const std::string& pattern : patterns) - if (Op::HasAttrMap(pattern)) - attr_maps.push_back(Op::GetAttrMap(pattern)); + if (Op::HasAttrMap(pattern)) attr_maps.push_back(Op::GetAttrMap(pattern)); return attr_maps; } From 690ed4d2fa2fa757ebddba49fbd9865240ec7a2e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 27 Apr 2021 23:46:10 -0700 Subject: [PATCH 04/16] Fix Merge Error --- src/target/intrin_rule.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index c519cdc2bcd5..7177df2fb6f2 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -124,7 +124,8 @@ TVM_REGISTER_OP("tir.sigmoid") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); - return isinf(call->args[0]); + auto one = make_const(call->args[0].dtype(), 1); + return one / (one + exp(-call->args[0])); }); TVM_REGISTER_OP("tir.isfinite") From 329a5565fdf1761501cb35260f65f17fd6632fb0 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 28 Apr 2021 10:54:42 -0700 Subject: [PATCH 05/16] Add blank line --- src/target/intrin_rule.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 7177df2fb6f2..942ec2ec167f 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -218,6 +218,7 @@ TVM_REGISTER_OP("tir.q_multiply_shift") return cast(lp_dtype, x); } }); + } // namespace legalize } // namespace codegen } // namespace tvm From a58120609c1fa67df92a5f4a05dc22f27c2b108b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Wed, 28 Apr 2021 16:24:55 -0700 Subject: [PATCH 06/16] Retrigger CI From bb20cf7a59cb613d84b5bfe7b9ee944fa42fbd1a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Apr 2021 09:51:39 -0700 Subject: [PATCH 07/16] Retrigger CI From 032a75e16c9fc770aeb2b9e29e4156283a1acf67 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Apr 2021 11:30:39 -0700 Subject: [PATCH 08/16] Fix fma issue --- src/tir/transforms/lower_intrin.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 31bf10e6e82e..dcf6a852734a 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -54,7 +54,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { std::vector> attr_maps; for (const std::string& pattern : patterns) - if (Op::HasAttrMap(pattern)) attr_maps.push_back(Op::GetAttrMap(pattern)); + if (Op::HasAttrMap(pattern)) { + attr_maps.push_back(Op::GetAttrMap(pattern)); + if (fma_ == nullptr && type_name == "FLowerIntrinsic") { + fma_ = (*attr_maps.rbegin()).get(Op::Get("tir.fma"), nullptr); + } + } return attr_maps; } @@ -286,7 +291,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); + PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -300,7 +305,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // patterns std::vector> lower_intrin_maps_; std::vector> legalize_maps_; - const PackedFunc* fma_{nullptr}; + // only intrinsic lowering function for tir.fma is supported now + FLowerIntrinsic fma_{nullptr}; bool support_bitwise_op_{true}; }; From b276c862ba2aef16b1fbd43ffd2b2f5b8accd00d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Apr 2021 17:22:02 -0700 Subject: [PATCH 09/16] Add fma test --- tests/python/unittest/test_tir_intrin.py | 70 +++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 8512d1c311eb..4241c45c28d0 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -16,9 +16,10 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm import topi from tvm.contrib import utils, clang +from tvm.script import ty import numpy as np import ctypes import math @@ -184,6 +185,72 @@ def clz_np(x, dtype): np.testing.assert_equal(b.asnumpy(), ref) +@tvm.script.tir +class Module: + def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) + n = tir.var("int32") + stride = tir.var("int32") + stride_1 = tir.var("int32") + stride_2 = tir.var("int32") + stride_3 = tir.var("int32") + A_1 = tir.match_buffer( + A, + [n], + strides=[stride], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + B_1 = tir.match_buffer( + B, + [n], + strides=[stride_1], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + C_1 = tir.match_buffer( + C, + [n], + strides=[stride_2], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + d_1 = tir.match_buffer( + d, + [n], + strides=[stride_3], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + # body + for i in tir.serial(0, n): + d_1.data[(i * stride_3)] = ( + tir.load("float32", A_1.data, (i * stride)) + * tir.load("float32", B_1.data, (i * stride_1)) + ) + tir.load("float32", C_1.data, (i * stride_2)) + + +def test_fma(): + opt = tvm.transform.Sequential( + [ + tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))), + tvm.tir.transform.LowerIntrin(), + ] + ) + mod = opt(Module()) + assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" + assert int(mod["test_tir_fma"].body.body.value.args[0]) == 134 + + if __name__ == "__main__": test_nearbyint() test_unary_intrin() @@ -191,3 +258,4 @@ def clz_np(x, dtype): test_binary_intrin() test_ldexp() test_clz() + test_fma() From 22420f77039f4e5753ff34134d1f28203d783f3e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Apr 2021 19:50:57 -0700 Subject: [PATCH 10/16] Remove llvm function id check --- tests/python/unittest/test_tir_intrin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 4241c45c28d0..79b2819212b7 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -248,7 +248,6 @@ def test_fma(): ) mod = opt(Module()) assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" - assert int(mod["test_tir_fma"].body.body.value.args[0]) == 134 if __name__ == "__main__": From d55ba6652925112e5650e4d6a18ce050cf17227f Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 29 Apr 2021 22:18:56 -0700 Subject: [PATCH 11/16] Rerun CI From ecb64785c0639be0737e37d5bbf13d8e06446fe9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Apr 2021 10:21:52 -0700 Subject: [PATCH 12/16] Rerun CI again From 5ebd22df662a569b6add71807b945b0131b5c829 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Apr 2021 10:27:44 -0700 Subject: [PATCH 13/16] Add missing rsqrt op --- src/target/intrin_rule.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 942ec2ec167f..64796fb43a7a 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -120,6 +120,14 @@ TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", namespace legalize { using namespace tir; +TVM_REGISTER_OP("tir.rsqrt") + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const CallNode* call = e.as(); + ICHECK(call != nullptr); + auto one = make_const(call->args[0].dtype(), 1); + return one / sqrt(call->args[0]); + }); + TVM_REGISTER_OP("tir.sigmoid") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); From 6e21b2b85cc46f7dac15dd5717ab91e37ed86018 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Apr 2021 11:39:42 -0700 Subject: [PATCH 14/16] Fix lowering order --- src/tir/transforms/lower_intrin.cc | 57 +++++++++--------------------- 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index dcf6a852734a..2555002d29b0 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -39,41 +39,34 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; + using FLowerGeneral = runtime::TypedPackedFunc; - template - inline std::vector> retrieve_attr_maps(const std::string& type_name, - const std::string& target, - const std::string& mtriple) { + IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") + : IRMutatorWithAnalyzer(analyzer) { std::vector patterns; - patterns.push_back(target + "." + type_name); + patterns.push_back(target + ".FLowerIntrinsic"); + patterns.push_back(target + ".FLegalize"); bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); if (is_llvm_aarch64) { - patterns.push_back(target + ".aarch64." + type_name); + patterns.push_back(target + ".aarch64.FLowerIntrinsic"); + patterns.push_back(target + ".aarch64.FLegalize"); } - patterns.push_back("default." + type_name); + patterns.push_back("default.FLowerIntrinsic"); + patterns.push_back("default.FLegalize"); - std::vector> attr_maps; for (const std::string& pattern : patterns) if (Op::HasAttrMap(pattern)) { - attr_maps.push_back(Op::GetAttrMap(pattern)); - if (fma_ == nullptr && type_name == "FLowerIntrinsic") { - fma_ = (*attr_maps.rbegin()).get(Op::Get("tir.fma"), nullptr); + attr_maps_.push_back(Op::GetAttrMap(pattern)); + if (fma_ == nullptr) { + fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr); } } - return attr_maps; - } - - IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") - : IRMutatorWithAnalyzer(analyzer) { - std::vector patterns_; - lower_intrin_maps_ = retrieve_attr_maps("FLowerIntrinsic", target, mtriple); - legalize_maps_ = retrieve_attr_maps("FLegalize", target, mtriple); } PrimExpr VisitExpr_(const CallNode* op) final { if (auto* ptr_op = op->op.as()) { - for (const auto& f_lower_intrin_map : lower_intrin_maps_) { - FLowerIntrinsic f = f_lower_intrin_map.get(GetRef(ptr_op), nullptr); + for (const auto& f_attr_map : attr_maps_) { + FLowerGeneral f = f_attr_map.get(GetRef(ptr_op), nullptr); if (f != nullptr) { PrimExpr e = GetRef(op); PrimExpr r = f(e); @@ -86,20 +79,6 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } } - for (const auto& f_legalize_map : legalize_maps_) { - FLowerIntrinsic f = f_legalize_map.get(GetRef(ptr_op), nullptr); - if (f != nullptr) { - PrimExpr e = GetRef(op); - PrimExpr r = f(e); - ICHECK(r.defined()) << "legalize rule must always return valid Expr"; - if (!r.same_as(e)) { - r = this->VisitExpr(r); - if (r.defined()) { - return r; - } - } - } - } } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -302,11 +281,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - // patterns - std::vector> lower_intrin_maps_; - std::vector> legalize_maps_; - // only intrinsic lowering function for tir.fma is supported now - FLowerIntrinsic fma_{nullptr}; + // attribute maps, shared only when FLegalize == FLowerIntrinsic + std::vector> attr_maps_; + FLowerGeneral fma_{nullptr}; bool support_bitwise_op_{true}; }; From 7d977e2240ea12352a6caad271709899babdd6f1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Apr 2021 14:19:43 -0700 Subject: [PATCH 15/16] Update intrin_rule.cc --- src/target/intrin_rule.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 64796fb43a7a..e697d9b60273 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -118,6 +118,7 @@ TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", } // namespace intrin namespace legalize { + using namespace tir; TVM_REGISTER_OP("tir.rsqrt") From 9e589ea593fac0e03769f8e80069ea2c582fcc68 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 30 Apr 2021 17:20:17 -0700 Subject: [PATCH 16/16] Retrigger CI