diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 4ce7ce9f2291..a8a3d911ca8e 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -167,9 +167,15 @@ TVM_REGISTER_OP("tir.sinh") TVM_REGISTER_OP("tir.asin") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; + using namespace intrin; const tir::CallNode* call = e.as(); ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; + + PrimExpr threshold = make_const(x.dtype(), 0.5); + PrimExpr abs_x = tvm::abs(x); + PrimExpr use_lib = abs_x >= threshold; + PrimExpr x2 = x * x; PrimExpr term1 = x; PrimExpr term3 = term1 * x2 / make_const(x.dtype(), 6); @@ -178,25 +184,43 @@ TVM_REGISTER_OP("tir.asin") PrimExpr term9 = term7 * x2 * make_const(x.dtype(), 1225) / make_const(x.dtype(), 3456); PrimExpr term11 = term9 * x2 * make_const(x.dtype(), 3969) / make_const(x.dtype(), 28160); PrimExpr series = term1 + term3 + term5 + term7 + term9 + term11; - /* --- domain limit check --- */ + + PrimExpr lib_result = + ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); + PrimExpr lower = make_const(x.dtype(), -1.0); PrimExpr upper = make_const(x.dtype(), 1.0); PrimExpr out_range = tir::Or(x upper); - // Use a quiet NaN constant PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - // select: if out of [-1,1] → NaN, else → series - return tir::Select(out_range, nan_const, series); + + return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, series)); }); TVM_REGISTER_OP("tir.acos") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; + using namespace intrin; const tir::CallNode* call = e.as(); ICHECK(call != nullptr) << "Invalid call node in acos legalization"; const PrimExpr& x = call->args[0]; + + PrimExpr threshold = make_const(x.dtype(), 0.5); + PrimExpr abs_x = tvm::abs(x); + PrimExpr use_lib = abs_x >= threshold; + PrimExpr half_pi = make_const(x.dtype(), M_PI / 2); PrimExpr asin_x = asin(x); - return half_pi - asin_x; + PrimExpr formula_result = half_pi - asin_x; + + PrimExpr lib_result = + ::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e); + + PrimExpr lower = make_const(x.dtype(), -1.0); + PrimExpr upper = make_const(x.dtype(), 1.0); + PrimExpr out_range = tir::Or(x upper); + PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); + + return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, formula_result)); }); TVM_REGISTER_OP("tir.atan") diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py index 8dabdbb344f3..1e8c88e08e65 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tir-base/test_tir_intrin.py @@ -135,6 +135,58 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): run_test(*func, atol, rtol) +def test_asin_acos_boundary_values(): + """Test asin and acos with boundary values and threshold switching.""" + test_funcs = [ + (tvm.tir.asin, lambda x: np.arcsin(x)), + (tvm.tir.acos, lambda x: np.arccos(x)), + ] + + def run_test(tvm_intrin, np_func): + m = te.var("m") + A = te.placeholder((m,), name="A") + B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B") + + mod = te.create_prim_func([A, B]) + sch = tir.Schedule(mod) + func = tvm.compile(sch.mod, target="llvm") + + dev = tvm.cpu(0) + + # Test boundary values: ±1.0 (should use system library) + boundary_values = np.array([1.0, -1.0], dtype=np.float32) + a1 = tvm.runtime.tensor(boundary_values, dev) + b1 = tvm.runtime.tensor(np.empty_like(boundary_values), dev) + func(a1, b1) + tvm.testing.assert_allclose(b1.numpy(), np_func(boundary_values), atol=1e-5, rtol=1e-5) + + # Test values at threshold: ±0.5 (should use system library) + threshold_values = np.array([0.5, -0.5], dtype=np.float32) + a2 = tvm.runtime.tensor(threshold_values, dev) + b2 = tvm.runtime.tensor(np.empty_like(threshold_values), dev) + func(a2, b2) + tvm.testing.assert_allclose(b2.numpy(), np_func(threshold_values), atol=1e-4, rtol=1e-4) + + # Test values just below threshold: ±0.49 (should use Taylor series) + below_threshold_values = np.array([0.49, -0.49, 0.3, -0.3, 0.0], dtype=np.float32) + a3 = tvm.runtime.tensor(below_threshold_values, dev) + b3 = tvm.runtime.tensor(np.empty_like(below_threshold_values), dev) + func(a3, b3) + tvm.testing.assert_allclose( + b3.numpy(), np_func(below_threshold_values), atol=1e-3, rtol=1e-3 + ) + + # Test out-of-domain values: should return NaN + out_of_domain = np.array([1.1, -1.1, 2.0, -2.0], dtype=np.float32) + a4 = tvm.runtime.tensor(out_of_domain, dev) + b4 = tvm.runtime.tensor(np.empty_like(out_of_domain), dev) + func(a4, b4) + assert np.all(np.isnan(b4.numpy())), "Out-of-domain inputs should return NaN" + + for func in test_funcs: + run_test(*func) + + def test_binary_intrin(): test_funcs = [ (tvm.tir.atan2, lambda x1, x2: np.arctan2(x1, x2)), @@ -315,6 +367,7 @@ def test_fma(): test_nearbyint() test_unary_intrin() test_round_intrinsics_on_int() + test_asin_acos_boundary_values() test_binary_intrin() test_ldexp() test_clz()