Skip to content
34 changes: 29 additions & 5 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,15 @@ TVM_REGISTER_OP("tir.sinh")
TVM_REGISTER_OP("tir.asin")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using namespace intrin;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
const PrimExpr& x = call->args[0];

PrimExpr threshold = make_const(x.dtype(), 0.5);
PrimExpr abs_x = tvm::abs(x);
PrimExpr use_lib = abs_x >= threshold;

PrimExpr x2 = x * x;
PrimExpr term1 = x;
PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6);
Expand All @@ -178,25 +184,43 @@ TVM_REGISTER_OP("tir.asin")
PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456);
PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160);
PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11;
/* --- domain limit check --- */

PrimExpr lib_result =
::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);

PrimExpr lower = make_const(x.dtype(), -1.0);
PrimExpr upper = make_const(x.dtype(), 1.0);
PrimExpr out_range = tir::Or(x<lower, x> upper);
// Use a quiet NaN constant
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());
// select: if out of [-1,1] → NaN, else → series
return tir::Select(out_range, nan_const, series);

return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, series));
});

TVM_REGISTER_OP("tir.acos")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tir::make_const;
using namespace intrin;
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr) << "Invalid call node in acos legalization";
const PrimExpr& x = call->args[0];

PrimExpr threshold = make_const(x.dtype(), 0.5);
PrimExpr abs_x = tvm::abs(x);
PrimExpr use_lib = abs_x >= threshold;

PrimExpr half_pi = make_const(x.dtype(), M_PI / 2);
PrimExpr asin_x = asin(x);
return half_pi - asin_x;
PrimExpr formula_result = half_pi - asin_x;

PrimExpr lib_result =
::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);

PrimExpr lower = make_const(x.dtype(), -1.0);
PrimExpr upper = make_const(x.dtype(), 1.0);
PrimExpr out_range = tir::Or(x<lower, x> upper);
PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits<double>::quiet_NaN());

return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, formula_result));
});

TVM_REGISTER_OP("tir.atan")
Expand Down
53 changes: 53 additions & 0 deletions tests/python/tir-base/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,58 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):
run_test(*func, atol, rtol)


def test_asin_acos_boundary_values():
"""Test asin and acos with boundary values and threshold switching."""
test_funcs = [
(tvm.tir.asin, lambda x: np.arcsin(x)),
(tvm.tir.acos, lambda x: np.arccos(x)),
]

def run_test(tvm_intrin, np_func):
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B")

mod = te.create_prim_func([A, B])
sch = tir.Schedule(mod)
func = tvm.compile(sch.mod, target="llvm")

dev = tvm.cpu(0)

# Test boundary values: ±1.0 (should use system library)
boundary_values = np.array([1.0, -1.0], dtype=np.float32)
a1 = tvm.runtime.tensor(boundary_values, dev)
b1 = tvm.runtime.tensor(np.empty_like(boundary_values), dev)
func(a1, b1)
tvm.testing.assert_allclose(b1.numpy(), np_func(boundary_values), atol=1e-5, rtol=1e-5)

# Test values at threshold: ±0.5 (should use system library)
threshold_values = np.array([0.5, -0.5], dtype=np.float32)
a2 = tvm.runtime.tensor(threshold_values, dev)
b2 = tvm.runtime.tensor(np.empty_like(threshold_values), dev)
func(a2, b2)
tvm.testing.assert_allclose(b2.numpy(), np_func(threshold_values), atol=1e-4, rtol=1e-4)

# Test values just below threshold: ±0.49 (should use Taylor series)
below_threshold_values = np.array([0.49, -0.49, 0.3, -0.3, 0.0], dtype=np.float32)
a3 = tvm.runtime.tensor(below_threshold_values, dev)
b3 = tvm.runtime.tensor(np.empty_like(below_threshold_values), dev)
func(a3, b3)
tvm.testing.assert_allclose(
b3.numpy(), np_func(below_threshold_values), atol=1e-3, rtol=1e-3
)

# Test out-of-domain values: should return NaN
out_of_domain = np.array([1.1, -1.1, 2.0, -2.0], dtype=np.float32)
a4 = tvm.runtime.tensor(out_of_domain, dev)
b4 = tvm.runtime.tensor(np.empty_like(out_of_domain), dev)
func(a4, b4)
assert np.all(np.isnan(b4.numpy())), "Out-of-domain inputs should return NaN"

for func in test_funcs:
run_test(*func)


def test_binary_intrin():
test_funcs = [
(tvm.tir.atan2, lambda x1, x2: np.arctan2(x1, x2)),
Expand Down Expand Up @@ -315,6 +367,7 @@ def test_fma():
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()
test_asin_acos_boundary_values()
test_binary_intrin()
test_ldexp()
test_clz()
Expand Down