Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ pub fn atan2_partials<T: Float>(a: T, b: T) -> (T, T) {
if h == T::zero() {
(T::zero(), T::zero())
} else {
(b / h / h, T::zero() - a / h / h)
// Unary `-x` (via `Float: Neg<Output = Self>`) preserves the IEEE
// signed-zero invariant `-(+0.0) = -0.0`. `T::zero() - x` would
// flatten to `+0.0` at `x = +0.0` under round-to-nearest, silently
// changing sign-bit semantics observable by downstream `copysign`
// / `is_sign_negative` consumers.
(b / h / h, -a / h / h)
}
}

Expand Down
29 changes: 22 additions & 7 deletions src/tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,28 @@ impl<F: Float> Tape<F> {
/// are unaffected.
///
/// Callers on x86 where subnormal adjoints are expected can opt into
/// FTZ by setting `MXCSR` themselves (`_mm_setcsr(0x9FC0)` via
/// `core::arch::x86_64::_mm_setcsr`) around the reverse-sweep call.
/// Doing so in the library would change numerical semantics for
/// callers who depend on subnormal precision, so the choice is
/// deferred. ARM64 always flushes subnormals by default (FPCR.FZ=1
/// in AArch32 compatibility, FPCR.FZ16 et al. on AArch64), so this
/// warning is x86-specific.
/// FTZ by setting the MXCSR bit themselves. The correct read-modify-
/// write idiom (so the caller's existing rounding mode and exception
/// masks survive) is:
///
/// ```ignore
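/// use core::arch::x86_64::{_mm_getcsr, _mm_setcsr};
/// // NOTE: both intrinsics are deprecated since Rust 1.75, and Rust
/// // makes no guarantees about floating-point code that runs with a
/// // non-default MXCSR; treat this as an opt-in, at-your-own-risk
/// // workaround.
/// // MXCSR bit 15 = FTZ. Bit 6 = DAZ (input denormals flushed to
/// // zero) is *independent*; enable it too only if you also want
/// // subnormal inputs treated as zero.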
/// let saved = unsafe { _mm_getcsr() };
/// unsafe { _mm_setcsr(saved | (1 << 15)) };
/// // ... tape.reverse(...) ...
/// unsafe { _mm_setcsr(saved) }; // restore
/// ```
///
/// Doing this globally in the library would change numerical
/// semantics for callers who depend on subnormal precision or have
/// set a non-default rounding mode (e.g. interval arithmetic with
/// `FE_DOWNWARD`), so the choice is deferred.
///
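/// AArch64 does not flush subnormals by default either: `FPCR.FZ`
/// (and `FPCR.FZ16` for half precision) is normally clear, and only
/// AArch32 Advanced SIMD arithmetic is unconditionally flush-to-zero.
/// The MXCSR recipe above is therefore x86-specific; AArch64 callers
/// who want the same behaviour set `FPCR.FZ` instead.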
#[must_use]
pub fn reverse(&self, seed_index: u32) -> Vec<F> {
let mut adjoints = vec![F::zero(); self.num_variables as usize];
Expand Down
136 changes: 133 additions & 3 deletions tests/gpu_cpu_parity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,39 @@ fn build_tanh() -> (BytecodeTape<f64>, f64) {
/// Parity fixture for `asinh`: hands `record` a closure that applies
/// `asinh` to the first variable, together with the input slice `[1.0]`.
fn build_asinh() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [1.0];
    record(|vars: &[BReverse<f64>]| vars[0].asinh(), &INPUTS)
}
/// Parity fixture for `acosh`: hands `record` a closure that applies
/// `acosh` to the first variable, together with the input slice `[2.0]`.
fn build_acosh() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [2.0];
    record(|vars: &[BReverse<f64>]| vars[0].acosh(), &INPUTS)
}
/// Parity fixture for `atanh`: hands `record` a closure that applies
/// `atanh` to the first variable, together with the input slice `[0.5]`.
fn build_atanh() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [0.5];
    record(|vars: &[BReverse<f64>]| vars[0].atanh(), &INPUTS)
}
/// Parity fixture for `asin`: hands `record` a closure that applies
/// `asin` to the first variable, together with the input slice `[0.5]`.
fn build_asin() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [0.5];
    record(|vars: &[BReverse<f64>]| vars[0].asin(), &INPUTS)
}
/// Parity fixture for `acos`: hands `record` a closure that applies
/// `acos` to the first variable, together with the input slice `[0.5]`.
fn build_acos() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [0.5];
    record(|vars: &[BReverse<f64>]| vars[0].acos(), &INPUTS)
}
/// Parity fixture for `exp2`: hands `record` a closure that applies
/// `exp2` to the first variable, together with the input slice `[1.0]`.
fn build_exp2() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [1.0];
    record(|vars: &[BReverse<f64>]| vars[0].exp2(), &INPUTS)
}
/// Parity fixture for `log2`: hands `record` a closure that applies
/// `log2` to the first variable, together with the input slice `[1.0]`.
fn build_log2() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [1.0];
    record(|vars: &[BReverse<f64>]| vars[0].log2(), &INPUTS)
}
/// Parity fixture for `log10`: hands `record` a closure that applies
/// `log10` to the first variable, together with the input slice `[1.0]`.
fn build_log10() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [1.0];
    record(|vars: &[BReverse<f64>]| vars[0].log10(), &INPUTS)
}
/// Parity fixture for the `%` operator: hands `record` a closure that
/// computes `vars[0] % vars[1]`, together with the input slice
/// `[5.0, 2.0]`.
fn build_rem() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 2] = [5.0, 2.0];
    record(|vars: &[BReverse<f64>]| vars[0] % vars[1], &INPUTS)
}
/// Parity fixture for `powi`: hands `record` a closure that raises the
/// first variable to the integer power 3, together with the input slice
/// `[2.0]`.
fn build_powi() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [2.0];
    record(|vars: &[BReverse<f64>]| vars[0].powi(3), &INPUTS)
}
/// Parity fixture for `powf`: hands `record` a closure that raises the
/// first variable to the exponent `BReverse::constant(2.5)`, together
/// with the input slice `[2.0]`.
fn build_powf() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 1] = [2.0];
    record(
        |vars: &[BReverse<f64>]| vars[0].powf(BReverse::constant(2.5)),
        &INPUTS,
    )
}
/// Parity fixture for `hypot`: hands `record` a closure that computes
/// `vars[0].hypot(vars[1])`, together with the input slice `[3.0, 4.0]`.
fn build_hypot() -> (BytecodeTape<f64>, f64) {
    const INPUTS: [f64; 2] = [3.0, 4.0];
    record(|vars: &[BReverse<f64>]| vars[0].hypot(vars[1]), &INPUTS)
}
Expand Down Expand Up @@ -339,6 +372,91 @@ const PARITY_CASES: &[ParityCase] = &[
f32_ulp: 16,
f64_ulp: 16,
},
ParityCase {
name: "acosh",
n_inputs: 1,
build: build_acosh,
// acosh domain is a >= 1.
points: &[&[1.5], &[2.0], &[10.0]],
f32_ulp: 16,
f64_ulp: 16,
},
ParityCase {
name: "atanh",
n_inputs: 1,
build: build_atanh,
// atanh domain is |a| < 1.
points: &[&[0.0], &[0.25], &[-0.5], &[0.9]],
f32_ulp: 16,
f64_ulp: 16,
},
ParityCase {
name: "asin",
n_inputs: 1,
build: build_asin,
points: &[&[0.0], &[0.5], &[-0.25]],
f32_ulp: 16,
f64_ulp: 16,
},
ParityCase {
name: "acos",
n_inputs: 1,
build: build_acos,
points: &[&[0.0], &[0.5], &[-0.25]],
f32_ulp: 16,
f64_ulp: 16,
},
// Exp/Log extras
ParityCase {
name: "exp2",
n_inputs: 1,
build: build_exp2,
points: &[&[0.0], &[1.0], &[-1.0], &[3.0]],
f32_ulp: 8,
f64_ulp: 8,
},
ParityCase {
name: "log2",
n_inputs: 1,
build: build_log2,
points: &[&[1.0], &[2.0], &[8.0]],
f32_ulp: 8,
f64_ulp: 8,
},
ParityCase {
name: "log10",
n_inputs: 1,
build: build_log10,
points: &[&[1.0], &[10.0], &[100.0]],
f32_ulp: 8,
f64_ulp: 8,
},
// Powers — fragile ops Phase 7 specifically patched.
ParityCase {
name: "powi",
n_inputs: 1,
build: build_powi,
points: &[&[2.0], &[-3.0], &[0.5]],
f32_ulp: 8,
f64_ulp: 8,
},
ParityCase {
name: "powf",
n_inputs: 1,
build: build_powf,
points: &[&[2.0], &[0.5], &[10.0]],
f32_ulp: 32,
f64_ulp: 16,
},
// Remainder.
ParityCase {
name: "rem",
n_inputs: 2,
build: build_rem,
points: &[&[5.0, 2.0], &[7.5, 2.5], &[-3.0, 2.0]],
f32_ulp: 4,
f64_ulp: 4,
},
// Multi-arg
ParityCase {
name: "hypot",
Expand Down Expand Up @@ -439,7 +557,13 @@ fn ulp_diff_f64(a: f64, b: f64) -> u64 {
fn wgpu_parity_all_cases() {
let ctx = match WgpuContext::new() {
Some(c) => c,
None => return,
None => {
// Silent returns on no-GPU machines would pass the assertion
// without running any case. Surface the skip so `cargo test
// -- --nocapture` makes it visible.
eprintln!("SKIP: no wgpu adapter; parity test not executed");
return;
}
};
let mut failures = Vec::new();
for case in PARITY_CASES {
Expand Down Expand Up @@ -502,7 +626,10 @@ fn wgpu_parity_all_cases() {
fn cuda_f32_parity_all_cases() {
let ctx = match CudaContext::new() {
Some(c) => c,
None => return,
None => {
eprintln!("SKIP: no CUDA device; parity test not executed");
return;
}
};
let mut failures = Vec::new();
for case in PARITY_CASES {
Expand Down Expand Up @@ -565,7 +692,10 @@ fn cuda_f32_parity_all_cases() {
fn cuda_f64_parity_all_cases() {
let ctx = match CudaContext::new() {
Some(c) => c,
None => return,
None => {
eprintln!("SKIP: no CUDA device; parity test not executed");
return;
}
};
let mut failures = Vec::new();
for case in PARITY_CASES {
Expand Down
Loading