From 5de0b8831bf77b56c396fa5ba40a994e71939d38 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:19:59 +0200 Subject: [PATCH] perf: add a stack-allocated to_base_be --- proptest-regressions/fmt.txt | 7 ++ src/base_convert.rs | 147 ++++++++++++++++++++++++++++++++++- src/fmt.rs | 6 +- 3 files changed, 154 insertions(+), 6 deletions(-) create mode 100644 proptest-regressions/fmt.txt diff --git a/proptest-regressions/fmt.txt b/proptest-regressions/fmt.txt new file mode 100644 index 00000000..e26f28c9 --- /dev/null +++ b/proptest-regressions/fmt.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc d13527baca051245a2b2a1f5b9fc131f9c7c1bbe4f115bf1fd90e8b9bf86de09 # shrinks to value = 1329227995784915872903807060280344576 diff --git a/src/base_convert.rs b/src/base_convert.rs index 9f70a94f..4f0ceab9 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -2,7 +2,7 @@ use crate::{ algorithms::{addmul_nx1, mul_nx1}, Uint, }; -use core::fmt; +use core::{fmt, iter::FusedIterator, mem::MaybeUninit}; /// Error for [`from_base_le`][Uint::from_base_le] and /// [`from_base_be`][Uint::from_base_be]. @@ -49,6 +49,7 @@ impl Uint { /// /// Panics if the base is less than 2. #[inline] + #[track_caller] pub fn to_base_le(&self, base: u64) -> impl Iterator { SpigotLittle::new(self.limbs, base) } @@ -60,10 +61,26 @@ impl Uint { /// power of `10` that still fits `u64`. This way much fewer iterations /// are required to extract all the digits. /// + /// Use [`to_base_be_2`](Self::to_base_be_2) to extract the maximum number + /// of digits at once more efficiently. + /// /// # Panics /// /// Panics if the base is less than 2. + /// + /// # Examples + /// + /// ``` + /// let n = ruint::aliases::U64::from(1234); + /// assert_eq!(n.to_base_be(10).collect::>(), [1, 2, 3, 4]); + /// assert_eq!(n.to_base_be(1000000).collect::>(), [1234]); + /// + /// // `to_base_be_2` always returns digits maximally packed into `u64`s. + /// assert_eq!(n.to_base_be_2(10).collect::>(), [1234]); + /// assert_eq!(n.to_base_be_2(1000000).collect::>(), [1234]); + /// ``` #[inline] + #[track_caller] pub fn to_base_be(&self, base: u64) -> impl Iterator { // Use `to_base_le` if we can heap-allocate it to reverse the order, // as it only performs one division per iteration instead of two. @@ -80,6 +97,31 @@ impl Uint { } } + /// Returns an iterator over the base `base` digits of the number in + /// big-endian order. + /// + /// Always returns digits maximally packed into `u64`s. + /// Unlike [`to_base_be`], this method: + /// - never heap-allocates memory, so it's always faster + /// - always returns digits maximally packed into `u64`s, so passing the + /// constant base like `2`, `8`, instead of the highest power that fits in + /// u64 is not needed + /// + /// # Panics + /// + /// Panics if the base is less than 2. + /// + /// # Examples + /// + /// See [`to_base_be`]. + /// + /// [`to_base_be`]: Self::to_base_be + #[inline] + #[track_caller] + pub fn to_base_be_2(&self, base: u64) -> impl Iterator { + SpigotBig2::new(self.limbs, base) + } + /// Constructs the [`Uint`] from digits in the base `base` in little-endian. /// /// # Errors @@ -166,7 +208,7 @@ impl Uint { } // Multiply by base. // OPT: keep track of non-zero limbs and mul the minimum. - let mut carry: u128 = u128::from(digit); + let mut carry = u128::from(digit); #[allow(clippy::cast_possible_truncation)] for limb in &mut result.limbs { carry += u128::from(*limb) * u128::from(base); @@ -200,7 +242,7 @@ impl Iterator for SpigotLittle { type Item = u64; #[inline] - #[allow(clippy::cast_possible_truncation)] // Doesn't truncate + #[allow(clippy::cast_possible_truncation)] // Doesn't truncate. fn next(&mut self) -> Option { // Knuth Algorithm S. let mut zero: u64 = 0_u64; @@ -220,6 +262,8 @@ impl Iterator for SpigotLittle { } } +impl FusedIterator for SpigotLittle {} + /// Implementation of `to_base_be` when `alloc` feature is disabled. /// /// This is generally slower than simply reversing the result of `to_base_le` @@ -304,11 +348,84 @@ impl Iterator for SpigotBig #[cfg(not(feature = "alloc"))] impl core::iter::FusedIterator for SpigotBig {} +/// An iterator over the base `base` digits of the number in big-endian order. +/// +/// See [`Uint::to_base_be_2`] for more details. +struct SpigotBig2 { + buf: SpigotBuf, +} + +impl SpigotBig2 { + #[inline] + #[track_caller] + fn new(limbs: [u64; LIMBS], base: u64) -> Self { + Self { + buf: SpigotBuf::new(limbs, base), + } + } +} + +impl Iterator for SpigotBig2 { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + self.buf.next_back() + } +} + +impl FusedIterator for SpigotBig2 {} + +/// Collects [`SpigotLittle`] into a stack-allocated buffer. +/// +/// Base for [`SpigotBig2`]. +struct SpigotBuf { + end: usize, + buf: [[MaybeUninit; 2]; LIMBS], +} + +impl SpigotBuf { + #[inline] + #[track_caller] + fn new(limbs: [u64; LIMBS], mut base: u64) -> Self { + // We need to do this so we can guarantee that `buf` is big enough. + base = crate::utils::max_pow_u64(base); + + let mut buf = [[MaybeUninit::uninit(); 2]; LIMBS]; + // TODO(MSRV-1.80): let as_slice = buf.as_flattened_mut(); + let as_slice = unsafe { + core::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::>(), LIMBS * 2) + }; + let mut init = 0; + for (i, limb) in SpigotLittle::new(limbs, base).enumerate() { + debug_assert!( + i < as_slice.len(), + "base {base} too small for u64 digits of {LIMBS} limbs; this shouldn't happen \ + because of the `max_pow_u64` call above" + ); + unsafe { as_slice.get_unchecked_mut(i).write(limb) }; + init += 1; + } + Self { end: init, buf } + } + + #[inline] + fn next_back(&mut self) -> Option { + if self.end == 0 { + None + } else { + self.end -= 1; + Some(unsafe { *self.buf.as_ptr().cast::().add(self.end) }) + } + } +} + #[cfg(test)] #[allow(clippy::unreadable_literal)] #[allow(clippy::zero_prefixed_literal)] mod tests { use super::*; + use crate::utils::max_pow_u64; // 90630363884335538722706632492458228784305343302099024356772372330524102404852 const N: Uint<256, 4> = Uint::from_limbs([ @@ -377,6 +494,26 @@ mod tests { ); } + #[test] + fn test_to_base_be_2() { + assert_eq!( + Uint::<64, 1>::from(123456789) + .to_base_be_2(10) + .collect::>(), + vec![123456789] + ); + assert_eq!( + N.to_base_be_2(10000000000000000000_u64).collect::>(), + vec![ + 9, + 0630363884335538722, + 7066324924582287843, + 0534330209902435677, + 2372330524102404852 + ] + ); + } + #[test] fn test_from_base_be() { assert_eq!( @@ -435,6 +572,10 @@ mod tests { let digits = n.to_base_be(base); let n2 = Uint::::from_base_be(base, digits).unwrap(); assert_eq!(n, n2); + + let digits = n.to_base_be_2(base).collect::>(); + let n2 = Uint::::from_base_be(max_pow_u64(base), digits).unwrap(); + assert_eq!(n, n2); } let single = |x: u64| x..=x; diff --git a/src/fmt.rs b/src/fmt.rs index a70ce127..2465df9a 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -58,7 +58,7 @@ macro_rules! impl_fmt { // Use `BITS` for all bases since `generic_const_exprs` is not yet stable. let mut buffer = DisplayBuffer::::new(); let mut first = true; - for spigot in self.to_base_be(<$base>::MAX) { + for spigot in self.to_base_be_2(<$base>::MAX) { write!( buffer, concat!("{:0width$", $base_char, "}"), @@ -87,16 +87,16 @@ impl_fmt!(fmt::LowerHex; base::Hexadecimal, "x"); impl_fmt!(fmt::UpperHex; base::Hexadecimal, "X"); struct DisplayBuffer { - buf: [MaybeUninit; SIZE], len: usize, + buf: [MaybeUninit; SIZE], } impl DisplayBuffer { #[inline] const fn new() -> Self { Self { - buf: unsafe { MaybeUninit::uninit().assume_init() }, len: 0, + buf: unsafe { MaybeUninit::uninit().assume_init() }, } }