Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/runtime/builtin_fp16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,4 @@ TVM_DLL float __gnu_h2f_ieee(uint16_t a) {
}

#endif

// float -> half conversion under the name __truncsfhf2.
// Delegates to __gnu_f2h_ieee and returns its uint16_t result unchanged.
TVM_DLL uint16_t __truncsfhf2(float v) {
  const uint16_t half_bits = __gnu_f2h_ieee(v);
  return half_bits;
}
// half -> float conversion under the name __extendhfsf2.
// Delegates to __gnu_h2f_ieee and returns its float result unchanged.
TVM_DLL float __extendhfsf2(uint16_t v) {
  const float widened = __gnu_h2f_ieee(v);
  return widened;
}
}
227 changes: 227 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
#else
#include <llvm/Support/Dwarf.h>
#endif
#if TVM_LLVM_VERSION >= 60
#include <llvm/CodeGen/TargetSubtargetInfo.h>
#else
#include <llvm/Target/TargetSubtargetInfo.h>
#endif
#include <llvm/IR/Argument.h>
#include <llvm/IR/Attributes.h>
#include <llvm/IR/BasicBlock.h>
Expand Down Expand Up @@ -167,6 +172,45 @@ void CodeGenLLVM::InitTarget() {
LOG(WARNING) << "Set native vector bits to be 128 for " << arch_name;
}
}

#if TVM_LLVM_VERSION >= 60
bool use_float16_abi = false;
#if TVM_LLVM_VERSION >= 150
// For conversions between _Float16 and float, LLVM uses runtime functions
// __extendhfsf2 and __truncsfhf2. On X86 up until version 14, LLVM used
// "uint16_t" for representing _Float16. Starting with LLVM 15, half-precision
// values can be passed in XMM registers (i.e. as floating-point). This happens
// when the compilation target has SSE2 enabled (either directly, or by enabling
// a feature that implies SSE2).
// Because the names of the conversion functions remain unchanged, it is impossible
// for TVM to provide them in the runtime, and have them work in both cases.
// To alleviate this issue, emit these functions directly into the target module
// after detecting whether or not to use floating-point ABI. To allow the linker
// to remove potential duplicates (or if they are unused), they are weak and
// reside in a separate section (ELF).
llvm::Triple::ArchType arch_type = tm->getTargetTriple().getArch();
if (arch_type == llvm::Triple::x86 || arch_type == llvm::Triple::x86_64) {
// Detect if SSE2 is enabled. This determines whether float16 ABI is used.
std::stringstream os;
const char fname[] = "test_sse2";
os << "target triple = \"" << llvm_target_->GetTargetTriple() << "\"\n"
<< "define void @" << fname << "() #0 { ret void } attributes #0 = { \"target-cpu\"=\""
<< llvm_target_->GetCPU() << "\" ";
if (auto&& fs = llvm_target_->GetTargetFeatureString(); !fs.empty()) {
os << "\"target-features\"=\"" << fs << "\" ";
}
os << "}\n";
auto mod = llvm_target_->GetInstance().ParseIR(os.str());
auto* test_sse2 = mod->getFunction(fname);
ICHECK_NE(test_sse2, nullptr) << "Module creation error";
use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2");
}
#endif // TVM_LLVM_VERSION >= 150

// Call this function only with LLVM >= 6.0. The code it emits uses "dso_local"
// which was introduced in LLVM 6.
EmitFloat16ConversionBuiltins(use_float16_abi);
#endif // TVM_LLVM_VERSION >= 60
}

void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); }
Expand Down Expand Up @@ -949,6 +993,189 @@ void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) {
}
}

// Build the textual LLVM IR for the two half/float conversion functions
// (__truncsfhf2 and __extendhfsf2), parse it into a module of its own, and
// append that module to link_modules_.
//
// use_float16_abi selects how the half-precision value appears in the function
// signatures: as `half` when true, as `i16` when false. The conversion
// arithmetic itself always operates on integer bit patterns; only a thin
// cast at the boundary differs between the two variants.
void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
// The LLVM IR for these functions was obtained by compiling
//
// For integer ABI:
// __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
// __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
// For floating-point ABI:
// __truncXfYf2__<float, uint32_t, 23, _Float16, uint16_t, 10>(a);
// __extendXfYf2__<_Float16, uint16_t, 10, float, uint32_t, 23>(a);

// Body of the float -> half function. It reads the float argument %a0 and
// leaves the 16 result bits in %vlast. The `define` line with the entry label
// b0, the cast of %vlast to the return type, and the final `ret`/closing brace
// are supplied separately (trunc_header, short_cast_out, trunc_return).
static const char trunc_body[] = // __truncsfhf2
" %v0 = bitcast float %a0 to i32\n"
" %v1 = and i32 %v0, 2147483647\n"
" %v2 = add nsw i32 %v1, -947912704\n"
" %v3 = add nsw i32 %v1, -1199570944\n"
" %v4 = icmp ult i32 %v2, %v3\n"
" br i1 %v4, label %b1, label %b5\n"
"b1:\n"
" %v5 = lshr i32 %v0, 13\n"
" %v6 = and i32 %v5, 65535\n"
" %v7 = add nuw nsw i32 %v6, -114688\n"
" %v8 = and i32 %v0, 8191\n"
" %v9 = icmp ugt i32 %v8, 4096\n"
" br i1 %v9, label %b2, label %b3\n"
"b2:\n"
" %v10 = add nuw nsw i32 %v6, -114687\n"
" br label %b13\n"
"b3:\n"
" %v11 = icmp eq i32 %v8, 4096\n"
" br i1 %v11, label %b4, label %b13\n"
"b4:\n"
" %v12 = and i32 %v7, 65535\n"
" %v13 = and i32 %v5, 1\n"
" %v14 = add nuw nsw i32 %v12, %v13\n"
" br label %b13\n"
"b5:\n"
" %v15 = icmp ugt i32 %v1, 2139095040\n"
" br i1 %v15, label %b6, label %b7\n"
"b6:\n"
" %v16 = lshr i32 %v0, 13\n"
" %v17 = and i32 %v16, 511\n"
" %v18 = or i32 %v17, 32256\n"
" br label %b13\n"
"b7:\n"
" %v19 = icmp ugt i32 %v1, 1199570943\n"
" br i1 %v19, label %b13, label %b8\n"
"b8:\n"
" %v20 = icmp ult i32 %v1, 754974720\n"
" br i1 %v20, label %b13, label %b9\n"
"b9:\n"
" %v21 = lshr i32 %v1, 23\n"
" %v22 = sub nsw i32 113, %v21\n"
" %v23 = and i32 %v0, 8388607\n"
" %v24 = or i32 %v23, 8388608\n"
" %v25 = add nsw i32 %v21, -81\n"
" %v26 = shl i32 %v24, %v25\n"
" %v27 = icmp ne i32 %v26, 0\n"
" %v28 = lshr i32 %v24, %v22\n"
" %v29 = zext i1 %v27 to i32\n"
" %v30 = lshr i32 %v28, 13\n"
" %v31 = and i32 %v28, 8191\n"
" %v32 = or i32 %v31, %v29\n"
" %v33 = icmp ugt i32 %v32, 4096\n"
" br i1 %v33, label %b10, label %b11\n"
"b10:\n"
" %v34 = add nuw nsw i32 %v30, 1\n"
" br label %b13\n"
"b11:\n"
" %v35 = icmp eq i32 %v32, 4096\n"
" br i1 %v35, label %b12, label %b13\n"
"b12:\n"
" %v36 = and i32 %v30, 1\n"
" %v37 = add nuw nsw i32 %v36, %v30\n"
" br label %b13\n"
"b13:\n"
" %v38 = phi i32 [ %v18, %b6 ], [ %v10, %b2 ], [ %v14, %b4 ], [ %v7, %b3 ],\n"
" [ 31744, %b7 ], [ 0, %b8 ], [ %v34, %b10 ], [ %v37, %b12 ],\n"
" [ %v30, %b11 ]\n"
" %v39 = lshr i32 %v0, 16\n"
" %v40 = and i32 %v39, 32768\n"
" %v41 = or i32 %v38, %v40\n"
" %vlast = trunc i32 %v41 to i16\n";

// Body of the half -> float function. It expects the 16 input bits in the i16
// value %vinp (produced by short_cast_in) and, unlike trunc_body, already
// contains its own `ret` and the closing brace of the function.
static const char extend_body[] = // __extendhfsf2
" %v1 = and i16 %vinp, 32767\n"
" %v2 = zext i16 %v1 to i32\n"
" %v3 = add nsw i16 %v1, -1024\n"
" %v4 = icmp ult i16 %v3, 30720\n"
" br i1 %v4, label %b1, label %b2\n"
"b1:\n"
" %v5 = shl nuw nsw i32 %v2, 13\n"
" %v6 = add nuw nsw i32 %v5, 939524096\n"
" br label %b6\n"
"b2:\n"
" %v7 = icmp ugt i16 %v1, 31743\n"
" br i1 %v7, label %b3, label %b4\n"
"b3:\n"
" %v8 = shl nuw nsw i32 %v2, 13\n"
" %v9 = or i32 %v8, 2139095040\n"
" br label %b6\n"
"b4:\n"
" %v10 = icmp eq i16 %v1, 0\n"
" br i1 %v10, label %b6, label %b5\n"
"b5:\n"
" %v11 = icmp ult i16 %v1, 256\n"
" %v12 = lshr i32 %v2, 8\n"
" %v13 = select i1 %v11, i32 %v2, i32 %v12\n"
" %v14 = select i1 %v11, i32 32, i32 24\n"
" %v15 = icmp ult i32 %v13, 16\n"
" %v16 = lshr i32 %v13, 4\n"
" %v17 = add nsw i32 %v14, -4\n"
" %v18 = select i1 %v15, i32 %v13, i32 %v16\n"
" %v19 = select i1 %v15, i32 %v14, i32 %v17\n"
" %v20 = icmp ult i32 %v18, 4\n"
" %v21 = lshr i32 %v18, 2\n"
" %v22 = add nsw i32 %v19, -2\n"
" %v23 = select i1 %v20, i32 %v18, i32 %v21\n"
" %v24 = select i1 %v20, i32 %v19, i32 %v22\n"
" %v25 = icmp ult i32 %v23, 2\n"
" %v26 = sub nsw i32 0, %v23\n"
" %v27 = select i1 %v25, i32 %v26, i32 -2\n"
" %v28 = add nsw i32 %v27, %v24\n"
" %v29 = add nsw i32 %v28, -8\n"
" %v30 = shl i32 %v2, %v29\n"
" %v31 = xor i32 %v30, 8388608\n"
" %v32 = shl i32 %v28, 23\n"
" %v33 = sub i32 1124073472, %v32\n"
" %v34 = or i32 %v31, %v33\n"
" br label %b6\n"
"b6:\n"
" %v35 = phi i32 [ %v6, %b1 ], [ %v9, %b3 ], [ %v34, %b5 ], [ 0, %b4 ]\n"
" %v36 = and i16 %vinp, -32768\n"
" %v37 = zext i16 %v36 to i32\n"
" %v38 = shl nuw i32 %v37, 16\n"
" %v39 = or i32 %v35, %v38\n"
" %v40 = bitcast i32 %v39 to float\n"
" ret float %v40\n"
"}\n";

// IR type used for the half-precision value in both function signatures.
std::string short_type = use_float16_abi ? "half" : "i16";

// Glue between the signature type and the i16 values the bodies work with:
// short_cast_in defines %vinp from the argument %a0 (extend), and
// short_cast_out defines the returned %vres from %vlast (truncate).
std::string short_cast_in, short_cast_out;
if (use_float16_abi) {
short_cast_in = " %vinp = bitcast half %a0 to i16\n";
short_cast_out = " %vres = bitcast i16 %vlast to half\n";
} else {
// No-ops that preserve the i16 values.
short_cast_in = " %vinp = add i16 %a0, 0\n";
short_cast_out = " %vres = add i16 %vlast, 0\n";
}

llvm::Triple triple(llvm_target_->GetTargetTriple());

// A dedicated section is requested only for ELF object files; for any other
// object format the `section` clause is left out of the definitions.
static const char elf_section_name[] = ".text.tvm.fp16.conv";
std::string section = triple.getObjectFormat() == llvm::Triple::ELF
? std::string("section \"") + elf_section_name + "\" "
: "";

// Both functions are defined `weak dso_local` and share attribute group #0.
std::string trunc_header = "define weak dso_local " + short_type +
" @__truncsfhf2(float %a0) local_unnamed_addr #0 " + section +
"{\nb0:\n";
std::string trunc_return = " ret " + short_type + " %vres\n}\n";

std::string extend_header = "define weak dso_local float @__extendhfsf2(" + short_type +
" %a0) local_unnamed_addr #0 " + section + "{\nb0:\n";

// truncate = trunc_header + trunc_body + short_cast_out + trunc_return
// extend = extend_header + short_cast_in + extend_body

// Attribute group #0: carries the target CPU and feature string of the
// current target. The "target-features" entry is emitted even when the
// feature string is empty.
std::string attributes = "attributes #0 = { nounwind readnone \"target-cpu\"=\"" +
llvm_target_->GetCPU() + "\" \"target-features\"=\"" +
llvm_target_->GetTargetFeatureString() + "\" }\n";

// Assemble the complete module text: triple, data layout, the two function
// definitions, and finally the attribute group they reference.
auto data_layout = llvm_target_->GetOrCreateTargetMachine()->createDataLayout();
std::string module_ir = "target triple = \"" + llvm_target_->GetTargetTriple() + "\"\n" +
"target datalayout = \"" + data_layout.getStringRepresentation() +
"\"\n" + trunc_header + trunc_body + short_cast_out + trunc_return +
extend_header + short_cast_in + extend_body + attributes;

// Parse the text into its own module and hand ownership to link_modules_.
// NOTE(review): the result of ParseIR is not checked here; presumably ParseIR
// reports parse failures itself - confirm.
auto builtins_module = llvm_target_->GetInstance().ParseIR(module_ir);
link_modules_.push_back(std::move(builtins_module));
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
ICHECK_GE(op->args.size(), 2U);
Expand Down
8 changes: 8 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* \param func The function to set attributes on.
*/
void SetTargetAttributes(llvm::Function* func);
/*!
* \brief Emit LLVM IR for the half/float conversion functions __extendhfsf2
* and __truncsfhf2.
*
* The functions are written as textual IR, parsed into a module of their
* own, and that module is appended to link_modules_.
*
* \param use_float16_abi Whether to use floating-point or integer ABI: when
* true the half-precision argument/return value is typed `half` in the
* emitted signatures, otherwise it is typed `i16`.
*/
void EmitFloat16ConversionBuiltins(bool use_float16_abi);

/*!
* \brief Get the number of elements in the given vector value.
* \param vec The value, must be of a vector type.
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import ctypes
import json
import math
import numpy as np
import pytest
import re
import sys

import numpy as np
import pytest
import tvm
import tvm.testing
from tvm import te
Expand Down Expand Up @@ -854,7 +854,8 @@ def make_call_extern(caller, callee):
}
mod = tvm.IRModule(functions=functions)
ir_text = tvm.build(mod, None, target="llvm").get_source("ll")
matches = re.findall(r"^define[^@]*@([a-zA-Z_][a-zA-Z0-9_]*)", ir_text, re.MULTILINE)
# Skip functions whose names start with _.
matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE)
assert matches == sorted(matches)


Expand Down
74 changes: 59 additions & 15 deletions tests/python/unittest/test_target_codegen_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import platform
import pytest
import re
import textwrap

import tvm
from tvm import te

def test_fp16_to_fp32():
if tvm.target.codegen.llvm_version_major() < 6:
print(
"Skipping due to LLVM version being {} < 6".format(
tvm.target.codegen.llvm_version_major()
)
)
return
llvm_version = tvm.target.codegen.llvm_version_major()
machine = platform.machine()

import platform
if machine not in ["i386", "x86_64", "AMD64", "amd64"]:
pytest.skip(f"Requires x86_64/i386, but machine is {machine}", allow_module_level=True)

machine = platform.machine()
if machine not in ["x86_64", "i386", "AMD64"]:
print("Skipping test because the platform is: {} ".format(machine))
return

@tvm.testing.requires_llvm
@pytest.mark.skipif(llvm_version < 6, reason=f"Requires LLVM 6+, got {llvm_version}")
def test_fp16_to_fp32():
def fp16_to_fp32(target, width, match=None, not_match=None):
elements = 64
n = tvm.runtime.convert(elements)
Expand Down Expand Up @@ -63,5 +61,51 @@ def fp16_to_fp32(target, width, match=None, not_match=None):
fp16_to_fp32("llvm", 9, not_match="vcvtph2ps")


is_32bit = platform.architecture()[0] == "32bit"


@tvm.testing.requires_llvm
@pytest.mark.skipif(is_32bit, reason="Fails in CI due to architecture mismatch in JIT")
@pytest.mark.parametrize("feature_string", ["-sse2", "+sse2"])
def test_fp16_fp32_conversions(feature_string):
    """Run a float32 -> float16 -> float32 round trip through an AOT-built model.

    The Relay program casts two float32 vectors to float16, adds them, and
    casts the sum back to float32. It is built for an x86 triple with the
    given ``feature_string`` passed as ``--mattr`` (SSE2 disabled or enabled),
    executed on the CPU, and the output is compared with the expected sums
    using a relative tolerance of 1e-3.
    """
    relay_text = textwrap.dedent(
        """
        #[version = "0.0.5"]
        def @main(%inp : Tensor[(3), float32], %cst : Tensor[(3), float32]) {
        %1 = cast(%inp, dtype="float16");
        %2 = cast(%cst, dtype="float16");
        %3 = add(%1, %2);
        %4 = cast(%3, dtype="float32");
        %4
        }
        """
    )
    parsed_mod = tvm.parser.fromtext(relay_text)

    # Pick the triple's architecture from the machine detected at module level.
    if machine == "i386":
        arch = "i386"
    else:
        arch = "x86_64"

    # Build with the AOT executor; "cst" is bound as a constant parameter.
    constants = {"cst": np.array([1.0, 2.0, 3.0], dtype="float32")}
    target_str = f"llvm --mtriple={arch} --mattr={feature_string}"
    aot_executor = tvm.relay.backend.Executor(
        "aot", {"interface-api": "packed", "unpacked-api": False}
    )
    factory = tvm.relay.build(
        parsed_mod,
        params=constants,
        target=target_str,
        executor=aot_executor,
    )

    # Instantiate the first (and only queried) module on the CPU.
    first_name = factory["list_module_names"]()[0]
    device = tvm.cpu(0)
    runtime_mod = factory[first_name](device)

    input_values = np.array([1.1, 2.1, 3.1], dtype="float32")
    input_nd = tvm.nd.array(input_values, device=device)

    set_input = runtime_mod.get_function("set_input")
    run = runtime_mod.get_function("run")
    get_output = runtime_mod.get_function("get_output")

    set_input(0, input_nd)
    run()
    result = get_output(0)

    expected = np.array([2.1, 4.1, 6.1], dtype="float32")
    np.testing.assert_allclose(result.asnumpy(), expected, rtol=1e-3)


if __name__ == "__main__":
test_fp16_to_fp32()