diff --git a/src/kernels/mod.rs b/src/kernels/mod.rs index a475d90..7544349 100644 --- a/src/kernels/mod.rs +++ b/src/kernels/mod.rs @@ -55,7 +55,12 @@ pub fn atan2_partials(a: T, b: T) -> (T, T) { if h == T::zero() { (T::zero(), T::zero()) } else { - (b / h / h, T::zero() - a / h / h) + // Unary `-x` (via `Float: Neg`) preserves the IEEE + // signed-zero invariant `-(+0.0) = -0.0`. `T::zero() - x` would + // flatten to `+0.0` at `x = +0.0` under round-to-nearest, silently + // changing sign-bit semantics observable by downstream `copysign` + // / `is_sign_negative` consumers. + (b / h / h, -a / h / h) } } diff --git a/src/tape.rs b/src/tape.rs index 54a6950..442fd40 100644 --- a/src/tape.rs +++ b/src/tape.rs @@ -162,13 +162,28 @@ impl Tape { /// are unaffected. /// /// Callers on x86 where subnormal adjoints are expected can opt into - /// FTZ by setting `MXCSR` themselves (`_mm_setcsr(0x9FC0)` via - /// `core::arch::x86_64::_mm_setcsr`) around the reverse-sweep call. - /// Doing so in the library would change numerical semantics for - /// callers who depend on subnormal precision, so the choice is - /// deferred. ARM64 always flushes subnormals by default (FPCR.FZ=1 - /// in AArch32 compatibility, FPCR.FZ16 et al. on AArch64), so this - /// warning is x86-specific. + /// FTZ by setting MXCSR bit 15 themselves. The correct + /// read-modify-write idiom (so the caller's existing rounding mode and + /// exception masks survive) is: + /// + /// ```ignore + /// use core::arch::x86_64::{_mm_getcsr, _mm_setcsr}; + /// // MXCSR bit 15 = FTZ. Bit 6 = DAZ (input denormals flushed to + /// // zero) is *independent*; enable it too only if you also want + /// // subnormal inputs treated as zero. + /// let saved = unsafe { _mm_getcsr() }; + /// unsafe { _mm_setcsr(saved | (1 << 15)) }; + /// // ... tape.reverse(...) ...
+ /// unsafe { _mm_setcsr(saved) }; // restore + /// ``` + /// + /// Doing this globally in the library would change numerical + /// semantics for callers who depend on subnormal precision or have + /// set a non-default rounding mode (e.g. interval arithmetic with + /// `FE_DOWNWARD`), so the choice is left to the caller. + /// + /// This MXCSR recipe is x86-specific. AArch64 does not flush subnormals + /// by default either; its opt-in control is `FPCR.FZ` (and `FPCR.FZ16`). #[must_use] pub fn reverse(&self, seed_index: u32) -> Vec { let mut adjoints = vec![F::zero(); self.num_variables as usize]; diff --git a/tests/gpu_cpu_parity.rs b/tests/gpu_cpu_parity.rs index 4a963fe..b3ab0c6 100644 --- a/tests/gpu_cpu_parity.rs +++ b/tests/gpu_cpu_parity.rs @@ -115,6 +115,39 @@ fn build_tanh() -> (BytecodeTape, f64) { fn build_asinh() -> (BytecodeTape, f64) { record(|v: &[BReverse]| v[0].asinh(), &[1.0]) } +fn build_acosh() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].acosh(), &[2.0]) +} +fn build_atanh() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].atanh(), &[0.5]) +} +fn build_asin() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].asin(), &[0.5]) +} +fn build_acos() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].acos(), &[0.5]) +} +fn build_exp2() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].exp2(), &[1.0]) +} +fn build_log2() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].log2(), &[1.0]) +} +fn build_log10() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].log10(), &[1.0]) +} +fn build_rem() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0] % v[1], &[5.0, 2.0]) +} +fn build_powi() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].powi(3), &[2.0]) +} +fn build_powf() -> (BytecodeTape, f64) { + record( + |v: &[BReverse]| v[0].powf(BReverse::constant(2.5)), + &[2.0], + ) +} fn build_hypot() -> (BytecodeTape, f64) { record(|v: &[BReverse]| v[0].hypot(v[1]), &[3.0, 4.0]) } @@ -339,6 +372,91 @@ const PARITY_CASES:
&[ParityCase] = &[ f32_ulp: 16, f64_ulp: 16, }, + ParityCase { + name: "acosh", + n_inputs: 1, + build: build_acosh, + // acosh domain is a >= 1. + points: &[&[1.5], &[2.0], &[10.0]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "atanh", + n_inputs: 1, + build: build_atanh, + // atanh domain is |a| < 1. + points: &[&[0.0], &[0.25], &[-0.5], &[0.9]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "asin", + n_inputs: 1, + build: build_asin, + points: &[&[0.0], &[0.5], &[-0.25]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "acos", + n_inputs: 1, + build: build_acos, + points: &[&[0.0], &[0.5], &[-0.25]], + f32_ulp: 16, + f64_ulp: 16, + }, + // Exp/Log extras + ParityCase { + name: "exp2", + n_inputs: 1, + build: build_exp2, + points: &[&[0.0], &[1.0], &[-1.0], &[3.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "log2", + n_inputs: 1, + build: build_log2, + points: &[&[1.0], &[2.0], &[8.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "log10", + n_inputs: 1, + build: build_log10, + points: &[&[1.0], &[10.0], &[100.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + // Powers — fragile ops Phase 7 specifically patched. + ParityCase { + name: "powi", + n_inputs: 1, + build: build_powi, + points: &[&[2.0], &[-3.0], &[0.5]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "powf", + n_inputs: 1, + build: build_powf, + points: &[&[2.0], &[0.5], &[10.0]], + f32_ulp: 32, + f64_ulp: 16, + }, + // Remainder. + ParityCase { + name: "rem", + n_inputs: 2, + build: build_rem, + points: &[&[5.0, 2.0], &[7.5, 2.5], &[-3.0, 2.0]], + f32_ulp: 4, + f64_ulp: 4, + }, // Multi-arg ParityCase { name: "hypot", @@ -439,7 +557,13 @@ fn ulp_diff_f64(a: f64, b: f64) -> u64 { fn wgpu_parity_all_cases() { let ctx = match WgpuContext::new() { Some(c) => c, - None => return, + None => { + // Silent returns on no-GPU machines would pass the assertion + // without running any case. 
Surface the skip so `cargo test + // -- --nocapture` makes it visible. + eprintln!("SKIP: no wgpu adapter; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES { @@ -502,7 +626,10 @@ fn wgpu_parity_all_cases() { fn cuda_f32_parity_all_cases() { let ctx = match CudaContext::new() { Some(c) => c, - None => return, + None => { + eprintln!("SKIP: no CUDA device; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES { @@ -565,7 +692,10 @@ fn cuda_f32_parity_all_cases() { fn cuda_f64_parity_all_cases() { let ctx = match CudaContext::new() { Some(c) => c, - None => return, + None => { + eprintln!("SKIP: no CUDA device; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES {