diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 86ac3d351614..17baaf3e657a 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -22,6 +22,7 @@ */ #ifdef TVM_LLVM_VERSION +#include #include #include #include @@ -30,6 +31,8 @@ #include +#include "intrin_rule_llvm.h" + namespace tvm { namespace codegen { @@ -140,8 +143,8 @@ TVM_REGISTER_OP("tir.exp10") TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -TVM_REGISTER_OP("tir.fma").set_attr("rocm.FLowerIntrinsic", - DispatchPureExternOCML); +TVM_REGISTER_OP("tir.fma").set_attr( + "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_OP("tir.log").set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML);