From 89f4a0daac58874bfab287a24d8b6c671f53849c Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 12 Aug 2025 19:17:02 -0400 Subject: [PATCH] [CODEGEN][REFACTOR] tir.call_llvm_intrin to remove nargs This PR refactors tir.call_llvm_intrin to drop the leading nargs argument. nargs was originally introduced because prefetch had a different signature; that reason no longer holds, and it is unintuitive to pass nargs to call_llvm_intrin when the argument count is already evident from the call itself. After this update, tir.call_llvm_intrin takes the intrinsic arguments directly (see the usage sketch after the diff). --- include/tvm/tir/builtin.h | 2 +- include/tvm/tir/stmt.h | 5 ---- python/tvm/tir/tensor_intrin/arm_cpu.py | 18 +------------ python/tvm/tir/tensor_intrin/hexagon.py | 3 --- python/tvm/tir/tensor_intrin/rocm.py | 3 --- python/tvm/tir/tensor_intrin/x86.py | 3 --- src/target/llvm/codegen_arm.cc | 7 +---- src/target/llvm/codegen_llvm.cc | 27 ++++--------------- src/target/llvm/intrin_rule_llvm.cc | 1 - src/target/llvm/intrin_rule_llvm.h | 11 ++++++-- src/target/llvm/llvm_instance.h | 15 +++++++++++ src/target/llvm/llvm_module.cc | 15 +---------- src/tir/transforms/vectorize_loop.cc | 12 +++------ .../codegen/test_target_codegen_llvm.py | 8 +++--- .../test_hexagon/test_async_dma_pipeline.py | 4 --- .../test_hexagon/test_meta_schedule.py | 2 +- .../contrib/test_hexagon/test_parallel_hvx.py | 3 --- .../test_parallel_hvx_load_vtcm.py | 4 --- ...eta_schedule_postproc_rewrite_tensorize.py | 1 - .../test_meta_schedule_trace_apply.py | 2 +- tests/python/tir-base/test_tir_ops.py | 4 +-- ..._transform_lower_cross_thread_reduction.py | 2 +- .../test_tir_transform_vectorize.py | 12 +++------ .../tvmscript/test_tvmscript_printer_tir.py | 14 +++++----- .../tvmscript/test_tvmscript_roundtrip.py | 5 ---- 25 files changed, 55 insertions(+), 128 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index b4ed44fbff32..d3573c925daf 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -225,7 +225,7 @@ TVM_DLL const Op& call_spirv_pure_glsl450(); // TODO(tvm-team) revisit the builtins below // some of them can simply become ops with special codegen attr. /*! - * \brief Prefetch a cacheline + * \brief same signature as llvm.prefetch */ TVM_DLL const Op& prefetch(); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 37410b1271cc..bbdb7c272ed8 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1106,11 +1106,6 @@ constexpr const char* pragma_import_c = "pragma_import_c"; constexpr const char* pragma_import_llvm = "pragma_import_llvm"; /*! \brief Try to modify the AST to support Tensor Core */ constexpr const char* pragma_tensor_core = "pragma_tensor_core"; -/*! - * \brief Mark of prefetch scope, value=offset, - * run prefetch of Tensor on the current loop scope - */ -constexpr const char* prefetch_scope = "prefetch_scope"; /*! * \brief Marks the layout transforms to be used for a tensor.
* diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py index a6f3538846e7..0a5c0ea3a51a 100644 --- a/python/tvm/tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/tir/tensor_intrin/arm_cpu.py @@ -74,7 +74,6 @@ def neon_4x4_i8i8i32_impl( multiply_low = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), - T.uint32(2), vec_a, vec_b_low, dtype="int16x8", @@ -82,7 +81,6 @@ def neon_4x4_i8i8i32_impl( pairwise_reduction_low = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), - T.uint32(1), multiply_low, dtype="int32x4", ) @@ -91,7 +89,6 @@ def neon_4x4_i8i8i32_impl( multiply_high = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"), - T.uint32(2), vec_a, vec_b_high, dtype="int16x8", @@ -99,14 +96,12 @@ def neon_4x4_i8i8i32_impl( pairwise_reduction_high = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"), - T.uint32(1), multiply_high, dtype="int32x4", ) C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"), - T.uint32(2), pairwise_reduction_low, pairwise_reduction_high, dtype="int32x4", @@ -159,7 +154,6 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None: C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id(f"llvm.aarch64.neon.{instr}"), - T.uint32(3), vec_c, vec_a, vec_b, @@ -311,7 +305,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.ld1w.horiz", - T.uint32(4), predicate, input_ptr, sub_tile, @@ -335,7 +328,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.st1w.vert", - T.uint32(4), predicate, output_ptr, sub_tile, @@ -438,7 +430,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.ld1h.horiz", - T.uint32(4), ptrue_fp16, input_ptr, sub_tile_idx, @@ -450,7 +441,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.ld1h.horiz", - T.uint32(4), ptrue_fp16, input_ptr, sub_tile_idx, @@ -467,7 +457,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.st1w.vert", - T.uint32(4), ptrue_fp32, output_ptr, sub_tile_idx, @@ -479,7 +468,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.st1w.vert", - T.uint32(4), ptrue_fp32, output_ptr, sub_tile_idx + 2, @@ -692,7 +680,6 @@ def impl(): T.call_llvm_intrin( "void", fmopa_intrin, - T.uint32(5), sub_tile, input_1[1], input_2[1], @@ -713,7 +700,6 @@ def impl(): T.call_llvm_intrin( "void", "llvm.aarch64.sme.st1w.horiz", - T.uint32(4), _create_active_lane_mask( C, (vert_offset + slice_idx, horiz_offset), M ), @@ -752,9 +738,7 @@ def impl(c: T.handle) -> None: T.reads() T.writes(C[0:SVF2, 0:SVF2]) clear_all_tiles = T.int32(255) - T.evaluate( - T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", T.uint32(1), clear_all_tiles) - ) + T.evaluate(T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", clear_all_tiles)) return desc, impl diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 22dd9a977c65..631d6b353240 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -107,7 +107,6 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"), - T.uint32(3), C[T.ramp(T.int32(0), 1, 32)], B_i32x32, A_i32, @@ -149,7 +148,6 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: 
T.handle) -> Non C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"), - T.uint32(3), C[T.ramp(T.int32(0), 1, 32)], T.broadcast(A_i32, 32), B_i32x32, @@ -191,7 +189,6 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"), - T.uint32(3), C[T.ramp(T.int32(0), 1, 32)], T.Broadcast(A_i32, 32), B_i32x32, diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py index 12dabfb2cdb3..bfac2ca1d25b 100644 --- a/python/tvm/tir/tensor_intrin/rocm.py +++ b/python/tvm/tir/tensor_intrin/rocm.py @@ -39,7 +39,6 @@ def sdot4( C[0] += T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"), - T.uint32(4), T.reinterpret(A.vload([0], "int8x4"), dtype="int32"), T.reinterpret(B.vload([0], "int8x4"), dtype="int32"), T.int32(0), @@ -337,7 +336,6 @@ def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None: T.launch_thread(tx, WARP_SIZE) C[tx, 0:local_size_out] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id(mfma_intrin), - T.uint32(6), A[tx, 0:local_size], B[tx, 0:local_size], C[tx, 0:local_size_out], @@ -365,7 +363,6 @@ def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None: C[tx, 0:local_size_out] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id(mfma_intrin), - T.uint32(6), T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]), T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]), C[tx, 0:local_size_out], diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py index b4b6f07cd90e..8f9518ce459f 100644 --- a/python/tvm/tir/tensor_intrin/x86.py +++ b/python/tvm/tir/tensor_intrin/x86.py @@ -59,7 +59,6 @@ def dot_product_16x4_u8i8i32_vnni( C[T.ramp(T.int32(0), 1, 16)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), - T.uint32(3), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, @@ -86,7 +85,6 @@ def dot_product_16x4_u8i8i32_avx512( Red = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"), - T.uint32(2), A_u8x64, B_i8x64, dtype="int16x32", @@ -94,7 +92,6 @@ def dot_product_16x4_u8i8i32_avx512( C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"), - T.uint32(2), Red, T.int16x32(1), dtype="int32x16", diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 4abe66710755..3adcfc82bba8 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -67,7 +67,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { using namespace tir; - const PrimExpr& e = call->args[2]; + const PrimExpr& e = call->args[1]; llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop; llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu; @@ -77,7 +77,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { (total_size != 128 && total_size != 64)) { Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); - vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); } @@ -101,14 +100,12 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { ICHECK(c0 != nullptr); Array vcnt8_args; 
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); - vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); - vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); if (call->dtype.bits() == 16) { @@ -118,7 +115,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { // Accumulation 16->32bit Array vcnt32_args; vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); - vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); if (call->dtype.bits() == 32) { @@ -128,7 +124,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { // Accumulation 32->64bit Array vcnt64_args; vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); - vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 21dd62efa9b7..3fdce3a1031d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1351,34 +1351,18 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { - ICHECK_GE(op->args.size(), 2U); + ICHECK_GE(op->args.size(), 1U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); - int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector arg_type; - for (size_t i = 2; i < op->args.size(); ++i) { + for (size_t i = 1; i < op->args.size(); ++i) { arg_value.push_back(MakeValue(op->args[i])); - if (i - 2 < static_cast(num_signature)) { - arg_type.push_back(arg_value.back()->getType()); - } + arg_type.push_back(arg_value.back()->getType()); } - // LLVM's prefetch intrinsic returns "void", while TVM's prefetch - // returns int32. This causes problems because prefetch is one of - // those intrinsics that is generated automatically via the - // tvm.intrin.rule mechanism. Any other intrinsic with a type - // mismatch will have to be treated specially here. - // TODO(kparzysz-quic): fix this once TVM prefetch uses the same - // type as LLVM. - llvm::Type* return_type = - (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) : t_void_; + llvm::Type* return_type = GetLLVMType(GetRef(op)); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " -#if TVM_LLVM_VERSION >= 130 - << llvm::Intrinsic::getBaseName(id).str(); -#else - << llvm::Intrinsic::getName(id, {}); -#endif - + << llvmGetIntrinName(id); // In earlier versions of LLVM's, the prefetch intrinsic is not // overloaded, and always takes the first argument as i8*. If // this is the case, this argument should insert a cast to i8*. 
@@ -1391,7 +1375,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->CreatePointerCast(arg_value[0], llvmGetPointerTo(t_char_, addrspace)); } } - return builder_->CreateCall(f, arg_value); } else if (op->op.same_as(builtin::bitwise_and())) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 15cc4450901e..17de699e00b4 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -266,7 +266,6 @@ TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimEx ICHECK_EQ(call->args.size(), 1); Array cargs; cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); - cargs.push_back(IntImm(DataType::UInt(32), 2)); cargs.push_back(call->args[0]); cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef // LLVM requires that the return type must match the first argument type diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 4b64e92127d3..aa4f68d0b090 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -26,11 +26,14 @@ #ifdef TVM_LLVM_VERSION +#include #include #include #include #include +#include "llvm_instance.h" + namespace tvm { namespace codegen { // num_signature means number of arguments used to query signature @@ -41,7 +44,9 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); - cargs.push_back(IntImm(DataType::UInt(32), num_signature)); + ICHECK_EQ(call->args.size(), num_signature) + << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature + << " arguments, but got " << call->args.size(); for (PrimExpr arg : call->args) { cargs.push_back(arg); @@ -56,7 +61,9 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); - cargs.push_back(IntImm(DataType::UInt(32), num_signature)); + ICHECK_EQ(call->args.size(), num_signature) + << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature + << " arguments, but got " << call->args.size(); for (PrimExpr arg : call->args) { cargs.push_back(arg); } diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index f2468a8ef99f..a68637cc844e 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -51,6 +51,21 @@ #define llvmGetPointerTo(arg, offset) (arg->getPointerTo(offset)) #endif +#if TVM_LLVM_VERSION >= 130 +#define llvmGetIntrinName(id) \ + std::string(llvm::Intrinsic::getBaseName(static_cast(id))) +#elif TVM_LLVM_VERSION >= 40 +// This is the version of Intrinsic::getName that works for overloaded +// intrinsics. Helpfully, if we provide no types to this function, it +// will give us the overloaded name without the types appended. This +// should be enough information for most uses. 
+#define llvmGetIntrinName(id) \ + std::string(llvm::Intrinsic::getName(static_cast(id), {})) +#else +// Nothing to do, just return the intrinsic id number +#define llvmGetIntrinName(id) std::to_string(id) +#endif + namespace llvm { class LLVMContext; class MemoryBuffer; diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 924f520082fd..e5077b904b87 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -653,20 +653,7 @@ static void LLVMReflectionRegister() { #endif }) .def("target.llvm_get_intrinsic_name", - [](int64_t id) -> String { -#if TVM_LLVM_VERSION >= 130 - return std::string(llvm::Intrinsic::getBaseName(static_cast(id))); -#elif TVM_LLVM_VERSION >= 40 - // This is the version of Intrinsic::getName that works for overloaded - // intrinsics. Helpfully, if we provide no types to this function, it - // will give us the overloaded name without the types appended. This - // should be enough information for most uses. - return std::string(llvm::Intrinsic::getName(static_cast(id), {})); -#else - // Nothing to do, just return the intrinsic id number - return std::to_string(id); -#endif - }) + [](int64_t id) -> String { return llvmGetIntrinName(id); }) .def("target.llvm_get_system_x86_vendor", []() -> String { #if TVM_LLVM_VERSION >= 120 diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 705739d98ba9..8e350924501e 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -558,20 +558,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args; if (op->op.same_as(builtin::call_llvm_pure_intrin())) { // op->args[1], will give us total number of arguments to intrinsic - int num_signature = Downcast(op->args[1])->value; Array op_expr_args; - for (int i = 0; i < num_signature; i++) { + for (size_t i = 1; i < op->args.size(); ++i) { // Collect all intrinsic arguments - op_expr_args.push_back(op->args[i + 2]); + op_expr_args.push_back(op->args[i]); } // Generate RAMP nodes for intrinsic arguments Array updated_args = MutateArray(op_expr_args, &lane); - // Collect Intrinsic ID and no. of argument - for (int i = 0; i < 2; i++) { - new_args.push_back(op->args[i]); - } + new_args.push_back(op->args[0]); // Collect updated intrinsic arguments - for (int i = 0; i < num_signature; i++) { + for (size_t i = 0; i < updated_args.size(); ++i) { new_args.push_back(updated_args[i]); } } else { diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 2105a2a2c31b..15c030aeacf2 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -35,7 +35,7 @@ def test_llvm_intrin(): n = tvm.runtime.convert(4) A = ib.pointer("float32", name="A") args = [tvm.tir.call_intrin("handle", "tir.address_of", A[0]), 0, 3, 1] - ib.emit(tvm.tir.Evaluate(tvm.tir.Call("int32", "tir.prefetch", args))) + ib.emit(tvm.tir.Evaluate(tvm.tir.Call("void", "tir.prefetch", args))) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch")) @@ -47,7 +47,7 @@ def test_llvm_void_intrin(): ib = tvm.tir.ir_builder.create() A = ib.pointer("uint8", name="A") # Create an intrinsic that returns void. 
- x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, "uint32"), A.asobject().data) + x = tvm.tir.call_llvm_intrin("", "llvm.assume", tvm.tir.const(1, "int1")) ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) @@ -72,9 +72,7 @@ def test_llvm_overloaded_intrin(): def use_llvm_intrinsic(A, C): ib = tvm.tir.ir_builder.create() L = A.vload((0, 0)) - I = tvm.tir.call_llvm_pure_intrin( - "int32", "llvm.ctlz", tvm.tir.const(2, "uint32"), L, tvm.tir.const(0, "int1") - ) + I = tvm.tir.call_llvm_pure_intrin("int32", "llvm.ctlz", L, tvm.tir.const(0, "int1")) S = C.vstore((0, 0), I) ib.emit(S) return ib.get() diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index fe7a615531ae..c0e868cac065 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -206,7 +206,6 @@ def conv2d_async_non_contig( B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32") C[0:32] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - T.uint32(3), C[0:32], B_i32x32, A_i32, @@ -240,7 +239,6 @@ def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: c_buffer[vn_index, x] = 0 c_buffer[vn_index, T.ramp(0, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - T.uint32(3), c_buffer[vn_index, T.ramp(0, 1, 32)], T.reinterpret(a_buffer[vn_index, T.ramp(0, 1, 128)], dtype="int32x32"), T.reinterpret(w_buffer[vi_index, T.ramp(0, 1, 128)], dtype="int32x32"), @@ -656,7 +654,6 @@ def main( b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, dtype="int32x32") c_buffer[0:32] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - T.uint32(3), c_buffer[0:32], T.broadcast(a_i32, 32), b_i32x32, @@ -811,7 +808,6 @@ def main( b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, dtype="int32x32") c_buffer[0:32] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - T.uint32(3), c_buffer[0:32], T.broadcast(a_i32, 32), b_i32x32, diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index c2a5109aff3c..c7f9d2a00fed 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -294,7 +294,7 @@ def main( # type: ignore b_buffer[0, 0:128], dtype="int32x32" ) # type: ignore c_buffer[0:32] = T.call_llvm_pure_intrin( # type: ignore - 4390, T.uint32(3), c_buffer[0:32], b_i32x32, a_i32, dtype="int32x32" + 4390, c_buffer[0:32], b_i32x32, a_i32, dtype="int32x32" ) diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 8f77fa1c4016..43dcc0b70269 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -85,7 +85,6 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vmpybusv.128B"), - T.uint32(2), T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), dtype="int16x128", @@ -108,7 +107,6 @@ def operator(a: T.handle, b: 
T.handle, c: T.handle) -> None: vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vaddubh.128B"), - T.uint32(2), T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), dtype="int16x128", @@ -131,7 +129,6 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.uint32(2), T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), dtype="int32x32", diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a584997dd507..56c521d68ced 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -87,7 +87,6 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: vn_ind = T.axis.remap("S", [n]) c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.uint32(2), T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), dtype="int32x32", @@ -124,7 +123,6 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None: vn_ind = T.axis.remap("S", [n]) c_buffer[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.uint32(2), T.reinterpret( a_buffer[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], dtype="int32x32" ), @@ -168,7 +166,6 @@ def operator( vn_ind = T.axis.remap("S", [n]) c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.uint32(2), T.reinterpret( a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], dtype="int32x32", @@ -267,7 +264,6 @@ def operator( vn_ind = T.axis.remap("S", [n]) c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.uint32(2), T.reinterpret( a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], dtype="int32x32", diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py index 1272b35451f9..313657108c62 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py @@ -236,7 +236,6 @@ def main( C_i32x16 = C.vload([0], dtype="int32x16") C[T.ramp(0, 1, 16)] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"), - T.uint32(3), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py index c3a76e101fa7..637f3093d8e1 100644 --- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py +++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py @@ -1159,7 +1159,7 @@ def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, B_i8x64: T.int8x64 = B[0, 0:64] B_i32x16: 
T.int32x16 = T.reinterpret(B_i8x64, dtype="int32x16") C_i32x16: T.int32x16 = C[0:16] - C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(3), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") + C[0:16] = T.call_llvm_pure_intrin(T.uint32(intrin_id), C_i32x16, T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16") for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7): for ax4_fused in T.vectorized(16): with T.block("T_cast_8"): diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py index f2a18aeae519..dfa5cbab80c0 100644 --- a/tests/python/tir-base/test_tir_ops.py +++ b/tests/python/tir-base/test_tir_ops.py @@ -236,9 +236,9 @@ def test_comm_reducer(num_args): def test_llvm_intrin(): with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): - a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy", 0) + a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy") with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): - a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy", 0) + a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy") if __name__ == "__main__": diff --git a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py index 63700853b36a..18e16513f481 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py @@ -832,7 +832,7 @@ def single_reduction_loop_with_tensorize( B_i8x128 = B[0, 0:128] B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32") C[0:32] = T.call_llvm_pure_intrin( - 4217, T.uint32(3), C[0:32], T.broadcast(A_i32, 32), B_i32x32, dtype="int32x32" + 4217, C[0:32], T.broadcast(A_i32, 32), B_i32x32, dtype="int32x32" ) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 13bb1c60cb53..5a4d4ea17d08 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -781,16 +781,14 @@ class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): for j in T.vectorized(extent): - A[j] = T.call_llvm_pure_intrin( - "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j] - ) + A[j] = T.call_llvm_pure_intrin("float32", "llvm.sqrt", B[j]) @I.ir_module class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( - vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + vec_str, "llvm.sqrt", B[T.Ramp(0, 1, extent)] ) with tvm.target.Target(target): @@ -809,16 +807,14 @@ class Before: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): for j in T.vectorized(extent): - A[j] = T.call_llvm_pure_intrin( - "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j] - ) + A[j] = T.call_llvm_pure_intrin("int32", "llvm.lround", B[j]) @I.ir_module class After: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( - vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + vec_str, "llvm.lround", B[T.Ramp(0, 1, extent)] ) with pytest.raises(Exception) as e_info: diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py 
b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 267fae20cab3..be8b03357dde 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -493,10 +493,10 @@ def test_cast(): def test_llvm_intrin_imm(): - a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0)) - _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))') - a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0)) - _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))') + a = tir.call_llvm_intrin("int32x4", "llvm.donothing") + _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing")') + a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing") + _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing")') def test_binary_arith(): @@ -1034,16 +1034,14 @@ def test_vectorize_llvm_pure_intrin(): def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (4,), "float32") B = T.match_buffer(b, (4,), "float32") - A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin( - "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)] - ) + A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[T.Ramp(0, 1, 4)]) expected_output = """ # from tvm.script import tir as T @T.prim_func def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): - A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4]) + A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[0:4]) """ _assert_print(main, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 0e1b328844be..2be2e2e98d81 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -375,7 +375,6 @@ def mmult( for x_c in T.serial(0, 32): C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), - T.uint32(3), T.broadcast( A[ ( @@ -393,7 +392,6 @@ def mmult( ) C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), - T.uint32(3), T.broadcast( A[ ( @@ -416,7 +414,6 @@ def mmult( ) C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), - T.uint32(3), T.broadcast( A[ ( @@ -439,7 +436,6 @@ def mmult( ) C_global[T.ramp((x_c * 32), 1, 32)] = T.call_llvm_pure_intrin( T.uint32(97), - T.uint32(3), T.broadcast( A[ ( @@ -3216,7 +3212,6 @@ def ctpop(A: T.Buffer((16,), "uint8"), B: T.Buffer((16,), "uint8")) -> None: ) B[vi] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id("llvm.ctpop.i8"), - T.uint32(1), A[vi], dtype="uint8", )
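
Usage sketch: a minimal example of the calling convention after this change, mirroring the updated test_llvm_overloaded_intrin test above. It assumes a TVM build with LLVM enabled and with this patch applied; the constant operands are illustrative only.

import tvm

# Before this change, the intrinsic name was followed by an explicit nargs operand:
#   tvm.tir.call_llvm_pure_intrin(
#       "int32", "llvm.ctlz", tvm.tir.const(2, "uint32"), x, tvm.tir.const(0, "int1"))
# After this change, the intrinsic arguments are passed directly.
x = tvm.tir.const(1024, "int32")
leading_zeros = tvm.tir.call_llvm_pure_intrin(
    "int32", "llvm.ctlz", x, tvm.tir.const(0, "int1")  # second operand is the is_zero_undef flag
)
print(leading_zeros)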