From 3d3656b47986e3f777cb3aec471f5b88142f40b2 Mon Sep 17 00:00:00 2001 From: Greg von Nessi Date: Sat, 18 Apr 2026 21:52:59 +0100 Subject: [PATCH] Close three review findings from Phase 9 retroactive audit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Retroactive multi-agent review of the Phase 9 commits flagged: - `atan2_partials` regressed signed-zero behaviour. The refactor from inline `(b/h/h, -a/h/h)` to `(b/h/h, T::zero() - a/h/h)` flattens `∂atan2/∂a` at `a = +0.0` to `+0.0` under round-to-nearest, where unary negation correctly yields `-0.0`. Observable downstream via `is_sign_negative` / `copysign` on gradients. Restored unary `-a/h/h` and documented the IEEE signed-zero invariant. - `Tape::reverse` FTZ doc recommended `_mm_setcsr(0x9FC0)` — a full MXCSR overwrite that clobbers any pre-existing rounding mode (e.g. interval-arithmetic crates running with `FE_DOWNWARD`) and also enables DAZ (input denormals flushed) unannounced. Replaced with a read-modify-write idiom that only sets bit 15 (FTZ) and shows the restore. Clarified DAZ is separate. - `tests/gpu_cpu_parity.rs` omitted 10+ opcodes that the bytecode ISA carries, including `acosh` — a helper introduced by the same Phase 9 commit as the parity harness. Added 11 new cases: acosh, atanh, asin, acos, exp2, log2, log10, rem, powi, powf, plus filling gaps reviewed by the coverage audit. All runners (wgpu + CUDA f32 + CUDA f64) green on the expanded table. - Silent-skip when `WgpuContext::new()` or `CudaContext::new()` returns `None` now prints an explicit `eprintln!` so `cargo test -- --nocapture` surfaces the skip instead of reporting a green result that ran zero assertions. Verified on M4 Max (wgpu) and A100 via vast.ai (CUDA f32 + f64). --- src/kernels/mod.rs | 7 ++- src/tape.rs | 29 ++++++--- tests/gpu_cpu_parity.rs | 136 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 161 insertions(+), 11 deletions(-) diff --git a/src/kernels/mod.rs b/src/kernels/mod.rs index a475d90..7544349 100644 --- a/src/kernels/mod.rs +++ b/src/kernels/mod.rs @@ -55,7 +55,12 @@ pub fn atan2_partials(a: T, b: T) -> (T, T) { if h == T::zero() { (T::zero(), T::zero()) } else { - (b / h / h, T::zero() - a / h / h) + // Unary `-x` (via `Float: Neg`) preserves the IEEE + // signed-zero invariant `-(+0.0) = -0.0`. `T::zero() - x` would + // flatten to `+0.0` at `x = +0.0` under round-to-nearest, silently + // changing sign-bit semantics observable by downstream `copysign` + // / `is_sign_negative` consumers. + (b / h / h, -a / h / h) } } diff --git a/src/tape.rs b/src/tape.rs index 54a6950..442fd40 100644 --- a/src/tape.rs +++ b/src/tape.rs @@ -162,13 +162,28 @@ impl Tape { /// are unaffected. /// /// Callers on x86 where subnormal adjoints are expected can opt into - /// FTZ by setting `MXCSR` themselves (`_mm_setcsr(0x9FC0)` via - /// `core::arch::x86_64::_mm_setcsr`) around the reverse-sweep call. - /// Doing so in the library would change numerical semantics for - /// callers who depend on subnormal precision, so the choice is - /// deferred. ARM64 always flushes subnormals by default (FPCR.FZ=1 - /// in AArch32 compatibility, FPCR.FZ16 et al. on AArch64), so this - /// warning is x86-specific. + /// FTZ by setting the MXCSR bit themselves. The correct read-modify- + /// write idiom (so the caller's existing rounding mode and exception + /// masks survive) is: + /// + /// ```ignore + /// use core::arch::x86_64::{_mm_getcsr, _mm_setcsr}; + /// // MXCSR bit 15 = FTZ. Bit 6 = DAZ (input denormals flushed to + /// // zero) is *independent*; enable it too only if you also want + /// // subnormal inputs treated as zero. + /// let saved = unsafe { _mm_getcsr() }; + /// unsafe { _mm_setcsr(saved | (1 << 15)) }; + /// // ... tape.reverse(...) ... + /// unsafe { _mm_setcsr(saved) }; // restore + /// ``` + /// + /// Doing this globally in the library would change numerical + /// semantics for callers who depend on subnormal precision or have + /// set a non-default rounding mode (e.g. interval arithmetic with + /// `FE_DOWNWARD`), so the choice is deferred. + /// + /// ARM64 flushes subnormals by default (FPCR.FZ=1 on AArch32 and + /// FPCR.FZ et al. on AArch64), so this warning is x86-specific. #[must_use] pub fn reverse(&self, seed_index: u32) -> Vec { let mut adjoints = vec![F::zero(); self.num_variables as usize]; diff --git a/tests/gpu_cpu_parity.rs b/tests/gpu_cpu_parity.rs index 4a963fe..b3ab0c6 100644 --- a/tests/gpu_cpu_parity.rs +++ b/tests/gpu_cpu_parity.rs @@ -115,6 +115,39 @@ fn build_tanh() -> (BytecodeTape, f64) { fn build_asinh() -> (BytecodeTape, f64) { record(|v: &[BReverse]| v[0].asinh(), &[1.0]) } +fn build_acosh() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].acosh(), &[2.0]) +} +fn build_atanh() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].atanh(), &[0.5]) +} +fn build_asin() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].asin(), &[0.5]) +} +fn build_acos() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].acos(), &[0.5]) +} +fn build_exp2() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].exp2(), &[1.0]) +} +fn build_log2() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].log2(), &[1.0]) +} +fn build_log10() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].log10(), &[1.0]) +} +fn build_rem() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0] % v[1], &[5.0, 2.0]) +} +fn build_powi() -> (BytecodeTape, f64) { + record(|v: &[BReverse]| v[0].powi(3), &[2.0]) +} +fn build_powf() -> (BytecodeTape, f64) { + record( + |v: &[BReverse]| v[0].powf(BReverse::constant(2.5)), + &[2.0], + ) +} fn build_hypot() -> (BytecodeTape, f64) { record(|v: &[BReverse]| v[0].hypot(v[1]), &[3.0, 4.0]) } @@ -339,6 +372,91 @@ const PARITY_CASES: &[ParityCase] = &[ f32_ulp: 16, f64_ulp: 16, }, + ParityCase { + name: "acosh", + n_inputs: 1, + build: build_acosh, + // acosh domain is a >= 1. + points: &[&[1.5], &[2.0], &[10.0]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "atanh", + n_inputs: 1, + build: build_atanh, + // atanh domain is |a| < 1. + points: &[&[0.0], &[0.25], &[-0.5], &[0.9]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "asin", + n_inputs: 1, + build: build_asin, + points: &[&[0.0], &[0.5], &[-0.25]], + f32_ulp: 16, + f64_ulp: 16, + }, + ParityCase { + name: "acos", + n_inputs: 1, + build: build_acos, + points: &[&[0.0], &[0.5], &[-0.25]], + f32_ulp: 16, + f64_ulp: 16, + }, + // Exp/Log extras + ParityCase { + name: "exp2", + n_inputs: 1, + build: build_exp2, + points: &[&[0.0], &[1.0], &[-1.0], &[3.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "log2", + n_inputs: 1, + build: build_log2, + points: &[&[1.0], &[2.0], &[8.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "log10", + n_inputs: 1, + build: build_log10, + points: &[&[1.0], &[10.0], &[100.0]], + f32_ulp: 8, + f64_ulp: 8, + }, + // Powers — fragile ops Phase 7 specifically patched. + ParityCase { + name: "powi", + n_inputs: 1, + build: build_powi, + points: &[&[2.0], &[-3.0], &[0.5]], + f32_ulp: 8, + f64_ulp: 8, + }, + ParityCase { + name: "powf", + n_inputs: 1, + build: build_powf, + points: &[&[2.0], &[0.5], &[10.0]], + f32_ulp: 32, + f64_ulp: 16, + }, + // Remainder. + ParityCase { + name: "rem", + n_inputs: 2, + build: build_rem, + points: &[&[5.0, 2.0], &[7.5, 2.5], &[-3.0, 2.0]], + f32_ulp: 4, + f64_ulp: 4, + }, // Multi-arg ParityCase { name: "hypot", @@ -439,7 +557,13 @@ fn ulp_diff_f64(a: f64, b: f64) -> u64 { fn wgpu_parity_all_cases() { let ctx = match WgpuContext::new() { Some(c) => c, - None => return, + None => { + // Silent returns on no-GPU machines would pass the assertion + // without running any case. Surface the skip so `cargo test + // -- --nocapture` makes it visible. + eprintln!("SKIP: no wgpu adapter; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES { @@ -502,7 +626,10 @@ fn wgpu_parity_all_cases() { fn cuda_f32_parity_all_cases() { let ctx = match CudaContext::new() { Some(c) => c, - None => return, + None => { + eprintln!("SKIP: no CUDA device; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES { @@ -565,7 +692,10 @@ fn cuda_f32_parity_all_cases() { fn cuda_f64_parity_all_cases() { let ctx = match CudaContext::new() { Some(c) => c, - None => return, + None => { + eprintln!("SKIP: no CUDA device; parity test not executed"); + return; + } }; let mut failures = Vec::new(); for case in PARITY_CASES {