From 7523d15f88eb091ee23b01585858677150833ceb Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 22 Sep 2022 13:00:55 -0700 Subject: [PATCH] [LLVM] Emit fp16/fp32 builtins directly into target module For conversions between `_Float16` and `float`, LLVM uses runtime functions `__extendhfsf2` and `__truncsfhf2`. On X86 up until version 14, LLVM used `uint16_t` for representing `_Float16`. Starting with LLVM 15, half- precision values can be passed in XMM registers (i.e. as floating-point). This happens when the compilation target has SSE2 enabled (either directly, or by enabling a feature that implies SSE2). Because the names of the conversion functions remain unchanged, it is impossible for TVM to provide them in the runtime, and have them work in both cases. To solve this issue, emit these functions directly into the target module after detecting whether or not to use floating-point ABI. To allow the linker to remove potential duplicates (or if they are unused), they are weak and reside in a separate section. --- src/runtime/builtin_fp16.cc | 3 - src/target/llvm/codegen_llvm.cc | 227 ++++++++++++++++++ src/target/llvm/codegen_llvm.h | 8 + .../unittest/test_target_codegen_llvm.py | 7 +- .../unittest/test_target_codegen_x86.py | 74 ++++-- 5 files changed, 298 insertions(+), 21 deletions(-) diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 4b175fb3ff60..d229491a4c7b 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -48,7 +48,4 @@ TVM_DLL float __gnu_h2f_ieee(uint16_t a) { } #endif - -TVM_DLL uint16_t __truncsfhf2(float v) { return __gnu_f2h_ieee(v); } -TVM_DLL float __extendhfsf2(uint16_t v) { return __gnu_h2f_ieee(v); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 305358d079d0..ca9d577f64f6 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -34,6 +34,11 @@ #else #include #endif +#if TVM_LLVM_VERSION >= 60 +#include +#else +#include +#endif #include #include #include @@ -167,6 +172,45 @@ void CodeGenLLVM::InitTarget() { LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name; } } + +#if TVM_LLVM_VERSION >= 60 + bool use_float16_abi = false; +#if TVM_LLVM_VERSION >= 150 + // For conversions between _Float16 and float, LLVM uses runtime functions + // __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used + // "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision + // values can be passed in XMM registers (i.e. as floating-point). This happens + // when the compilation target has SSE2 enabled (either directly, or by enabling + // a feature that implies SSE2). + // Because the names of the conversion functions remain unchanged, it is impossible + // for TVM to provide them in the runtime, and have them work in both cases. + // To alleviate this issue, emit these functions directly into the target module + // after detecting whether or not to use floating-point ABI. To allow the linker + // to remove potential duplicates (or if they are unused), they are weak and + // reside in a separate section (ELF). + llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch(); + if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) { + // Detect if SSE2 is enabled. This determines whether float16 ABI is used. + std::stringstream os; + const char fname[] = "test_sse2"; + os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n" + << "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\"" + << llvm_target_->GetCPU() << "\" "; + if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) { + os << "\"target-features\"=\"" << fs << "\" "; + } + os << "}\n"; + auto mod = llvm_target_->GetInstance().ParseIR(os.str()); + auto* test_sse2 = mod->getFunction(fname); + ICHECK_NE(test_sse2, nullptr) << "Module creation error"; + use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2"); + } +#endif // TVM_LLVM_VERSION >= 150 + + // Call this function only with LLVM >= 6.0. The code it emits uses "dso_local" + // which was introduced in LLVM 6. + EmitFloat16ConversionBuiltins(use_float16_abi); +#endif // TVM_LLVM_VERSION >= 60 } void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } @@ -949,6 +993,189 @@ void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { } } +void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { + // The LLVM IR for these function was obtained by compiling + // + // For integer ABI: + // __truncXfYf2__(a); + // __extendXfYf2__(a); + // For floating-point ABI: + // __truncXfYf2__(a); + // __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a); + + static const char trunc_body[] = // __truncsfhf2 + " %v0 = bitcast float %a0 to i32\n" + " %v1 = and i32 %v0, 2147483647\n" + " %v2 = add nsw i32 %v1, -947912704\n" + " %v3 = add nsw i32 %v1, -1199570944\n" + " %v4 = icmp ult i32 %v2, %v3\n" + " br i1 %v4, label %b1, label %b5\n" + "b1:\n" + " %v5 = lshr i32 %v0, 13\n" + " %v6 = and i32 %v5, 65535\n" + " %v7 = add nuw nsw i32 %v6, -114688\n" + " %v8 = and i32 %v0, 8191\n" + " %v9 = icmp ugt i32 %v8, 4096\n" + " br i1 %v9, label %b2, label %b3\n" + "b2:\n" + " %v10 = add nuw nsw i32 %v6, -114687\n" + " br label %b13\n" + "b3:\n" + " %v11 = icmp eq i32 %v8, 4096\n" + " br i1 %v11, label %b4, label %b13\n" + "b4:\n" + " %v12 = and i32 %v7, 65535\n" + " %v13 = and i32 %v5, 1\n" + " %v14 = add nuw nsw i32 %v12, %v13\n" + " br label %b13\n" + "b5:\n" + " %v15 = icmp ugt i32 %v1, 2139095040\n" + " br i1 %v15, label %b6, label %b7\n" + "b6:\n" + " %v16 = lshr i32 %v0, 13\n" + " %v17 = and i32 %v16, 511\n" + " %v18 = or i32 %v17, 32256\n" + " br label %b13\n" + "b7:\n" + " %v19 = icmp ugt i32 %v1, 1199570943\n" + " br i1 %v19, label %b13, label %b8\n" + "b8:\n" + " %v20 = icmp ult i32 %v1, 754974720\n" + " br i1 %v20, label %b13, label %b9\n" + "b9:\n" + " %v21 = lshr i32 %v1, 23\n" + " %v22 = sub nsw i32 113, %v21\n" + " %v23 = and i32 %v0, 8388607\n" + " %v24 = or i32 %v23, 8388608\n" + " %v25 = add nsw i32 %v21, -81\n" + " %v26 = shl i32 %v24, %v25\n" + " %v27 = icmp ne i32 %v26, 0\n" + " %v28 = lshr i32 %v24, %v22\n" + " %v29 = zext i1 %v27 to i32\n" + " %v30 = lshr i32 %v28, 13\n" + " %v31 = and i32 %v28, 8191\n" + " %v32 = or i32 %v31, %v29\n" + " %v33 = icmp ugt i32 %v32, 4096\n" + " br i1 %v33, label %b10, label %b11\n" + "b10:\n" + " %v34 = add nuw nsw i32 %v30, 1\n" + " br label %b13\n" + "b11:\n" + " %v35 = icmp eq i32 %v32, 4096\n" + " br i1 %v35, label %b12, label %b13\n" + "b12:\n" + " %v36 = and i32 %v30, 1\n" + " %v37 = add nuw nsw i32 %v36, %v30\n" + " br label %b13\n" + "b13:\n" + " %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n" + " [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n" + " [ %v30, %b11 ]\n" + " %v39 = lshr i32 %v0, 16\n" + " %v40 = and i32 %v39, 32768\n" + " %v41 = or i32 %v38, %v40\n" + " %vlast = trunc i32 %v41 to i16\n"; + + static const char extend_body[] = // __extendhfsf2 + " %v1 = and i16 %vinp, 32767\n" + " %v2 = zext i16 %v1 to i32\n" + " %v3 = add nsw i16 %v1, -1024\n" + " %v4 = icmp ult i16 %v3, 30720\n" + " br i1 %v4, label %b1, label %b2\n" + "b1:\n" + " %v5 = shl nuw nsw i32 %v2, 13\n" + " %v6 = add nuw nsw i32 %v5, 939524096\n" + " br label %b6\n" + "b2:\n" + " %v7 = icmp ugt i16 %v1, 31743\n" + " br i1 %v7, label %b3, label %b4\n" + "b3:\n" + " %v8 = shl nuw nsw i32 %v2, 13\n" + " %v9 = or i32 %v8, 2139095040\n" + " br label %b6\n" + "b4:\n" + " %v10 = icmp eq i16 %v1, 0\n" + " br i1 %v10, label %b6, label %b5\n" + "b5:\n" + " %v11 = icmp ult i16 %v1, 256\n" + " %v12 = lshr i32 %v2, 8\n" + " %v13 = select i1 %v11, i32 %v2, i32 %v12\n" + " %v14 = select i1 %v11, i32 32, i32 24\n" + " %v15 = icmp ult i32 %v13, 16\n" + " %v16 = lshr i32 %v13, 4\n" + " %v17 = add nsw i32 %v14, -4\n" + " %v18 = select i1 %v15, i32 %v13, i32 %v16\n" + " %v19 = select i1 %v15, i32 %v14, i32 %v17\n" + " %v20 = icmp ult i32 %v18, 4\n" + " %v21 = lshr i32 %v18, 2\n" + " %v22 = add nsw i32 %v19, -2\n" + " %v23 = select i1 %v20, i32 %v18, i32 %v21\n" + " %v24 = select i1 %v20, i32 %v19, i32 %v22\n" + " %v25 = icmp ult i32 %v23, 2\n" + " %v26 = sub nsw i32 0, %v23\n" + " %v27 = select i1 %v25, i32 %v26, i32 -2\n" + " %v28 = add nsw i32 %v27, %v24\n" + " %v29 = add nsw i32 %v28, -8\n" + " %v30 = shl i32 %v2, %v29\n" + " %v31 = xor i32 %v30, 8388608\n" + " %v32 = shl i32 %v28, 23\n" + " %v33 = sub i32 1124073472, %v32\n" + " %v34 = or i32 %v31, %v33\n" + " br label %b6\n" + "b6:\n" + " %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n" + " %v36 = and i16 %vinp, -32768\n" + " %v37 = zext i16 %v36 to i32\n" + " %v38 = shl nuw i32 %v37, 16\n" + " %v39 = or i32 %v35, %v38\n" + " %v40 = bitcast i32 %v39 to float\n" + " ret float %v40\n" + "}\n"; + + std::string short_type = use_float16_abi ? "half" : "i16"; + + std::string short_cast_in, short_cast_out; + if (use_float16_abi) { + short_cast_in = " %vinp = bitcast half %a0 to i16\n"; + short_cast_out = " %vres = bitcast i16 %vlast to half\n"; + } else { + // No-ops that preserve the i16 values. + short_cast_in = " %vinp = add i16 %a0, 0\n"; + short_cast_out = " %vres = add i16 %vlast, 0\n"; + } + + llvm::Triple triple(llvm_target_->GetTargetTriple()); + + static const char elf_section_name[] = ".text.tvm.fp16.conv"; + std::string section = triple.getObjectFormat() == llvm::Triple::ELF + ? std::string("section \"") + elf_section_name + "\" " + : ""; + + std::string trunc_header = "define weak dso_local " + short_type + + " @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section + + "{\nb0:\n"; + std::string trunc_return = " ret " + short_type + " %vres\n}\n"; + + std::string extend_header = "define weak dso_local float @__extendhfsf2(" + short_type + + " %a0) local_unnamed_addr #0 " + section + "{\nb0:\n"; + + // truncate = trunc_header + trunc_body + short_cast_out + trunc_return + // extend = extend_header + short_cast_in + extend_body + + std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" + + llvm_target_->GetCPU() + "\" \"target-features\"=\"" + + llvm_target_->GetTargetFeatureString() + "\" }\n"; + + auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout(); + std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" + + "target datalayout = \"" + data_layout.getStringRepresentation() + + "\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return + + extend_header + short_cast_in + extend_body + attributes; + + auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir); + link_modules_.push_back(std::move(builtins_module)); +} + llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { ICHECK_GE(op->args.size(), 2U); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e6321be647aa..7a8daf2e761f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -395,6 +395,14 @@ class CodeGenLLVM : public ExprFunctor, * \param func The function to set attributes on. */ void SetTargetAttributes(llvm::Function* func); + /*! + * \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2 + * into the current llvm::Module. + * + * \param use_float16_abi Whether to use floating-point or integer ABI. + */ + void EmitFloat16ConversionBuiltins(bool use_float16_abi); + /*! * \brief Get the number of elements in the given vector value. * \param vec The value, must be of a vector type. diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c57648382827..e179d17101a3 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -18,11 +18,11 @@ import ctypes import json import math +import numpy as np +import pytest import re import sys -import numpy as np -import pytest import tvm import tvm.testing from tvm import te @@ -854,7 +854,8 @@ def make_call_extern(caller, callee): } mod = tvm.IRModule(functions=functions) ir_text = tvm.build(mod, None, target="llvm").get_source("ll") - matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) + # Skip functions whose names start with _. + matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) assert matches == sorted(matches) diff --git a/tests/python/unittest/test_target_codegen_x86.py b/tests/python/unittest/test_target_codegen_x86.py index ec42e0a4d749..af91ed4520fd 100644 --- a/tests/python/unittest/test_target_codegen_x86.py +++ b/tests/python/unittest/test_target_codegen_x86.py @@ -14,27 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import tvm -from tvm import te +import numpy as np +import platform +import pytest import re +import textwrap +import tvm +from tvm import te -def test_fp16_to_fp32(): - if tvm.target.codegen.llvm_version_major() < 6: - print( - "Skipping due to LLVM version being {} < 6".format( - tvm.target.codegen.llvm_version_major() - ) - ) - return +llvm_version = tvm.target.codegen.llvm_version_major() +machine = platform.machine() - import platform +if machine not in ["i386", "x86_64", "AMD64", "amd64"]: + pytest.skip(f"Requires x86_64/i386, but machine is {machine}", allow_module_level=True) - machine = platform.machine() - if machine not in ["x86_64", "i386", "AMD64"]: - print("Skipping test because the platform is: {} ".format(machine)) - return +@tvm.testing.requires_llvm +@pytest.mark.skipif(llvm_version < 6, reason=f"Requires LLVM 6+, got {llvm_version}") +def test_fp16_to_fp32(): def fp16_to_fp32(target, width, match=None, not_match=None): elements = 64 n = tvm.runtime.convert(elements) @@ -63,5 +61,51 @@ def fp16_to_fp32(target, width, match=None, not_match=None): fp16_to_fp32("llvm", 9, not_match="vcvtph2ps") +is_32bit = platform.architecture()[0] == "32bit" + + +@tvm.testing.requires_llvm +@pytest.mark.skipif(is_32bit, reason=f"Fails in CI due to architecture mismatch in JIT") +@pytest.mark.parametrize("feature_string", ["-sse2", "+sse2"]) +def test_fp16_fp32_conversions(feature_string): + relay_model = textwrap.dedent( + """ + #[version = "0.0.5"] + def @main(%inp : Tensor[(3), float32], %cst : Tensor[(3), float32]) { + %1 = cast(%inp, dtype="float16"); + %2 = cast(%cst, dtype="float16"); + %3 = add(%1, %2); + %4 = cast(%3, dtype="float32"); + %4 + } + """ + ) + + ir_mod = tvm.parser.fromtext(relay_model) + + arch = "i386" if machine == "i386" else "x86_64" + aot_factory = tvm.relay.build( + ir_mod, + params={"cst": np.array([1.0, 2.0, 3.0], dtype="float32")}, + target=f"llvm --mtriple={arch} --mattr={feature_string}", + executor=tvm.relay.backend.Executor( + "aot", {"interface-api": "packed", "unpacked-api": False} + ), + ) + + mod_name = aot_factory["list_module_names"]()[0] + executor = aot_factory[mod_name] + mod = executor(tvm.cpu(0)) + + inp = tvm.nd.array(np.array([1.1, 2.1, 3.1], dtype="float32"), device=tvm.cpu(0)) + + mod.get_function("set_input")(0, inp) + mod.get_function("run")() + out = mod.get_function("get_output")(0) + + expected = np.array([2.1, 4.1, 6.1], dtype="float32") + np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-3) + + if __name__ == "__main__": test_fp16_to_fp32()