diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 4b175fb3ff60..d229491a4c7b 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -48,7 +48,4 @@ TVM_DLL float __gnu_h2f_ieee(uint16_t a) { } #endif - -TVM_DLL uint16_t __truncsfhf2(float v) { return __gnu_f2h_ieee(v); } -TVM_DLL float __extendhfsf2(uint16_t v) { return __gnu_h2f_ieee(v); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 305358d079d0..ca9d577f64f6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -34,6 +34,11 @@ #else #include #endif +#if TVM_LLVM_VERSION >= 60 +#include +#else +#include +#endif #include #include #include @@ -167,6 +172,45 @@ void CodeGenLLVM::InitTarget() { LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name; } } + +#if TVM_LLVM_VERSION >= 60 + bool use_float16_abi = false; +#if TVM_LLVM_VERSION >= 150 + // For conversions between _Float16 and float, LLVM uses runtime functions + // __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used + // "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision + // values can be passed in XMM registers (i.e. as floating-point). This happens + // when the compilation target has SSE2 enabled (either directly, or by enabling + // a feature that implies SSE2). + // Because the names of the conversion functions remain unchanged, it is impossible + // for TVM to provide them in the runtime, and have them work in both cases. + // To alleviate this issue, emit these functions directly into the target module + // after detecting whether or not to use floating-point ABI. To allow the linker + // to remove potential duplicates (or if they are unused), they are weak and + // reside in a separate section (ELF). + llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch(); + if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) { + // Detect if SSE2 is enabled. This determines whether float16 ABI is used. + std::stringstream os; + const char fname[] = "test_sse2"; + os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n" + << "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\"" + << llvm_target_->GetCPU() << "\" "; + if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) { + os << "\"target-features\"=\"" << fs << "\" "; + } + os << "}\n"; + auto mod = llvm_target_->GetInstance().ParseIR(os.str()); + auto* test_sse2 = mod->getFunction(fname); + ICHECK_NE(test_sse2, nullptr) << "Module creation error"; + use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2"); + } +#endif // TVM_LLVM_VERSION >= 150 + + // Call this function only with LLVM >= 6.0. The code it emits uses "dso_local" + // which was introduced in LLVM 6. + EmitFloat16ConversionBuiltins(use_float16_abi); +#endif // TVM_LLVM_VERSION >= 60 } void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } @@ -949,6 +993,189 @@ void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { } } +void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { + // The LLVM IR for these function was obtained by compiling + // + // For integer ABI: + // __truncXfYf2__(a); + // __extendXfYf2__(a); + // For floating-point ABI: + // __truncXfYf2__(a); + // __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a); + + static const char trunc_body[] = // __truncsfhf2 + " %v0 = bitcast float %a0 to i32\n" + " %v1 = and i32 %v0, 2147483647\n" + " %v2 = add nsw i32 %v1, -947912704\n" + " %v3 = add nsw i32 %v1, -1199570944\n" + " %v4 = icmp ult i32 %v2, %v3\n" + " br i1 %v4, label %b1, label %b5\n" + "b1:\n" + " %v5 = lshr i32 %v0, 13\n" + " %v6 = and i32 %v5, 65535\n" + " %v7 = add nuw nsw i32 %v6, -114688\n" + " %v8 = and i32 %v0, 8191\n" + " %v9 = icmp ugt i32 %v8, 4096\n" + " br i1 %v9, label %b2, label %b3\n" + "b2:\n" + " %v10 = add nuw nsw i32 %v6, -114687\n" + " br label %b13\n" + "b3:\n" + " %v11 = icmp eq i32 %v8, 4096\n" + " br i1 %v11, label %b4, label %b13\n" + "b4:\n" + " %v12 = and i32 %v7, 65535\n" + " %v13 = and i32 %v5, 1\n" + " %v14 = add nuw nsw i32 %v12, %v13\n" + " br label %b13\n" + "b5:\n" + " %v15 = icmp ugt i32 %v1, 2139095040\n" + " br i1 %v15, label %b6, label %b7\n" + "b6:\n" + " %v16 = lshr i32 %v0, 13\n" + " %v17 = and i32 %v16, 511\n" + " %v18 = or i32 %v17, 32256\n" + " br label %b13\n" + "b7:\n" + " %v19 = icmp ugt i32 %v1, 1199570943\n" + " br i1 %v19, label %b13, label %b8\n" + "b8:\n" + " %v20 = icmp ult i32 %v1, 754974720\n" + " br i1 %v20, label %b13, label %b9\n" + "b9:\n" + " %v21 = lshr i32 %v1, 23\n" + " %v22 = sub nsw i32 113, %v21\n" + " %v23 = and i32 %v0, 8388607\n" + " %v24 = or i32 %v23, 8388608\n" + " %v25 = add nsw i32 %v21, -81\n" + " %v26 = shl i32 %v24, %v25\n" + " %v27 = icmp ne i32 %v26, 0\n" + " %v28 = lshr i32 %v24, %v22\n" + " %v29 = zext i1 %v27 to i32\n" + " %v30 = lshr i32 %v28, 13\n" + " %v31 = and i32 %v28, 8191\n" + " %v32 = or i32 %v31, %v29\n" + " %v33 = icmp ugt i32 %v32, 4096\n" + " br i1 %v33, label %b10, label %b11\n" + "b10:\n" + " %v34 = add nuw nsw i32 %v30, 1\n" + " br label %b13\n" + "b11:\n" + " %v35 = icmp eq i32 %v32, 4096\n" + " br i1 %v35, label %b12, label %b13\n" + "b12:\n" + " %v36 = and i32 %v30, 1\n" + " %v37 = add nuw nsw i32 %v36, %v30\n" + " br label %b13\n" + "b13:\n" + " %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n" + " [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n" + " [ %v30, %b11 ]\n" + " %v39 = lshr i32 %v0, 16\n" + " %v40 = and i32 %v39, 32768\n" + " %v41 = or i32 %v38, %v40\n" + " %vlast = trunc i32 %v41 to i16\n"; + + static const char extend_body[] = // __extendhfsf2 + " %v1 = and i16 %vinp, 32767\n" + " %v2 = zext i16 %v1 to i32\n" + " %v3 = add nsw i16 %v1, -1024\n" + " %v4 = icmp ult i16 %v3, 30720\n" + " br i1 %v4, label %b1, label %b2\n" + "b1:\n" + " %v5 = shl nuw nsw i32 %v2, 13\n" + " %v6 = add nuw nsw i32 %v5, 939524096\n" + " br label %b6\n" + "b2:\n" + " %v7 = icmp ugt i16 %v1, 31743\n" + " br i1 %v7, label %b3, label %b4\n" + "b3:\n" + " %v8 = shl nuw nsw i32 %v2, 13\n" + " %v9 = or i32 %v8, 2139095040\n" + " br label %b6\n" + "b4:\n" + " %v10 = icmp eq i16 %v1, 0\n" + " br i1 %v10, label %b6, label %b5\n" + "b5:\n" + " %v11 = icmp ult i16 %v1, 256\n" + " %v12 = lshr i32 %v2, 8\n" + " %v13 = select i1 %v11, i32 %v2, i32 %v12\n" + " %v14 = select i1 %v11, i32 32, i32 24\n" + " %v15 = icmp ult i32 %v13, 16\n" + " %v16 = lshr i32 %v13, 4\n" + " %v17 = add nsw i32 %v14, -4\n" + " %v18 = select i1 %v15, i32 %v13, i32 %v16\n" + " %v19 = select i1 %v15, i32 %v14, i32 %v17\n" + " %v20 = icmp ult i32 %v18, 4\n" + " %v21 = lshr i32 %v18, 2\n" + " %v22 = add nsw i32 %v19, -2\n" + " %v23 = select i1 %v20, i32 %v18, i32 %v21\n" + " %v24 = select i1 %v20, i32 %v19, i32 %v22\n" + " %v25 = icmp ult i32 %v23, 2\n" + " %v26 = sub nsw i32 0, %v23\n" + " %v27 = select i1 %v25, i32 %v26, i32 -2\n" + " %v28 = add nsw i32 %v27, %v24\n" + " %v29 = add nsw i32 %v28, -8\n" + " %v30 = shl i32 %v2, %v29\n" + " %v31 = xor i32 %v30, 8388608\n" + " %v32 = shl i32 %v28, 23\n" + " %v33 = sub i32 1124073472, %v32\n" + " %v34 = or i32 %v31, %v33\n" + " br label %b6\n" + "b6:\n" + " %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n" + " %v36 = and i16 %vinp, -32768\n" + " %v37 = zext i16 %v36 to i32\n" + " %v38 = shl nuw i32 %v37, 16\n" + " %v39 = or i32 %v35, %v38\n" + " %v40 = bitcast i32 %v39 to float\n" + " ret float %v40\n" + "}\n"; + + std::string short_type = use_float16_abi ? "half" : "i16"; + + std::string short_cast_in, short_cast_out; + if (use_float16_abi) { + short_cast_in = " %vinp = bitcast half %a0 to i16\n"; + short_cast_out = " %vres = bitcast i16 %vlast to half\n"; + } else { + // No-ops that preserve the i16 values. + short_cast_in = " %vinp = add i16 %a0, 0\n"; + short_cast_out = " %vres = add i16 %vlast, 0\n"; + } + + llvm::Triple triple(llvm_target_->GetTargetTriple()); + + static const char elf_section_name[] = ".text.tvm.fp16.conv"; + std::string section = triple.getObjectFormat() == llvm::Triple::ELF + ? std::string("section \"") + elf_section_name + "\" " + : ""; + + std::string trunc_header = "define weak dso_local " + short_type + + " @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section + + "{\nb0:\n"; + std::string trunc_return = " ret " + short_type + " %vres\n}\n"; + + std::string extend_header = "define weak dso_local float @__extendhfsf2(" + short_type + + " %a0) local_unnamed_addr #0 " + section + "{\nb0:\n"; + + // truncate = trunc_header + trunc_body + short_cast_out + trunc_return + // extend = extend_header + short_cast_in + extend_body + + std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" + + llvm_target_->GetCPU() + "\" \"target-features\"=\"" + + llvm_target_->GetTargetFeatureString() + "\" }\n"; + + auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout(); + std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" + + "target datalayout = \"" + data_layout.getStringRepresentation() + + "\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return + + extend_header + short_cast_in + extend_body + attributes; + + auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir); + link_modules_.push_back(std::move(builtins_module)); +} + llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { ICHECK_GE(op->args.size(), 2U); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e6321be647aa..7a8daf2e761f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -395,6 +395,14 @@ class CodeGenLLVM : public ExprFunctor, * \param func The function to set attributes on. */ void SetTargetAttributes(llvm::Function* func); + /*! + * \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2 + * into the current llvm::Module. + * + * \param use_float16_abi Whether to use floating-point or integer ABI. + */ + void EmitFloat16ConversionBuiltins(bool use_float16_abi); + /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c57648382827..e179d17101a3 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -18,11 +18,11 @@ import ctypes import json import math +import numpy as np +import pytest import re import sys -import numpy as np -import pytest import tvm import tvm.testing from tvm import te @@ -854,7 +854,8 @@ def make_call_extern(caller, callee): } mod = tvm.IRModule(functions=functions) ir_text = tvm.build(mod, None, target="llvm").get_source("ll") - matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) + # Skip functions whose names start with _. + matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) assert matches == sorted(matches) diff --git a/tests/python/unittest/test_target_codegen_x86.py b/tests/python/unittest/test_target_codegen_x86.py index ec42e0a4d749..af91ed4520fd 100644 --- a/tests/python/unittest/test_target_codegen_x86.py +++ b/tests/python/unittest/test_target_codegen_x86.py @@ -14,27 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te +import numpy as np +import platform +import pytest import re +import textwrap +import tvm +from tvm import te -def test_fp16_to_fp32(): - if tvm.target.codegen.llvm_version_major() < 6: - print( - "Skipping due to LLVM version being {} < 6".format( - tvm.target.codegen.llvm_version_major() - ) - ) - return +llvm_version = tvm.target.codegen.llvm_version_major() +machine = platform.machine() - import platform +if machine not in ["i386", "x86_64", "AMD64", "amd64"]: + pytest.skip(f"Requires x86_64/i386, but machine is {machine}", allow_module_level=True) - machine = platform.machine() - if machine not in ["x86_64", "i386", "AMD64"]: - print("Skipping test because the platform is: {} ".format(machine)) - return +@tvm.testing.requires_llvm +@pytest.mark.skipif(llvm_version < 6, reason=f"Requires LLVM 6+, got {llvm_version}") +def test_fp16_to_fp32(): def fp16_to_fp32(target, width, match=None, not_match=None): elements = 64 n = tvm.runtime.convert(elements) @@ -63,5 +61,51 @@ def fp16_to_fp32(target, width, match=None, not_match=None): fp16_to_fp32("llvm", 9, not_match="vcvtph2ps") +is_32bit = platform.architecture()[0] == "32bit" + + +@tvm.testing.requires_llvm +@pytest.mark.skipif(is_32bit, reason=f"Fails in CI due to architecture mismatch in JIT") +@pytest.mark.parametrize("feature_string", ["-sse2", "+sse2"]) +def test_fp16_fp32_conversions(feature_string): + relay_model = textwrap.dedent( + """ + #[version = "0.0.5"] + def @main(%inp : Tensor[(3), float32], %cst : Tensor[(3), float32]) { + %1 = cast(%inp, dtype="float16"); + %2 = cast(%cst, dtype="float16"); + %3 = add(%1, %2); + %4 = cast(%3, dtype="float32"); + %4 + } + """ + ) + + ir_mod = tvm.parser.fromtext(relay_model) + + arch = "i386" if machine == "i386" else "x86_64" + aot_factory = tvm.relay.build( + ir_mod, + params={"cst": np.array([1.0, 2.0, 3.0], dtype="float32")}, + target=f"llvm --mtriple={arch} --mattr={feature_string}", + executor=tvm.relay.backend.Executor( + "aot", {"interface-api": "packed", "unpacked-api": False} + ), + ) + + mod_name = aot_factory["list_module_names"]()[0] + executor = aot_factory[mod_name] + mod = executor(tvm.cpu(0)) + + inp = tvm.nd.array(np.array([1.1, 2.1, 3.1], dtype="float32"), device=tvm.cpu(0)) + + mod.get_function("set_input")(0, inp) + mod.get_function("run")() + out = mod.get_function("get_output")(0) + + expected = np.array([2.1, 4.1, 6.1], dtype="float32") + np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-3) + + if __name__ == "__main__": test_fp16_to_fp32()