diff --git a/CHANGELOG.md b/CHANGELOG.md index 755c95f0..0486f294 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `bigdecimal` support ([#486]) - `PartialEq` and `PartialOrd` implementations for primitive integers; minor breaking change for type inference ([#491]) +### Changed + +- `to_base_be` and `core::fmt` trait implementations are available without the "alloc" feature ([#488]) + ### Fixed - Check limb overflow in shift ops ([#476]) @@ -22,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#476]: https://github.com/recmo/uint/pull/476 [#483]: https://github.com/recmo/uint/pull/483 [#486]: https://github.com/recmo/uint/pull/486 +[#488]: https://github.com/recmo/uint/pull/488 [#491]: https://github.com/recmo/uint/pull/491 ## [1.15.0] - 2025-05-22 diff --git a/src/base_convert.rs b/src/base_convert.rs index 53f7d544..9f70a94f 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -50,11 +50,7 @@ impl Uint { /// Panics if the base is less than 2. #[inline] pub fn to_base_le(&self, base: u64) -> impl Iterator { - assert!(base > 1); - SpigotLittle { - base, - limbs: self.limbs, - } + SpigotLittle::new(self.limbs, base) } /// Returns an iterator over the base `base` digits of the number in @@ -68,24 +64,19 @@ impl Uint { /// /// Panics if the base is less than 2. #[inline] - #[cfg(feature = "alloc")] // OPT: Find an allocation free method. Maybe extract from the top? pub fn to_base_be(&self, base: u64) -> impl Iterator { - struct OwnedVecIterator { - vec: alloc::vec::Vec, - } - - impl Iterator for OwnedVecIterator { - type Item = u64; - - #[inline] - fn next(&mut self) -> Option { - self.vec.pop() - } + // Use `to_base_le` if we can heap-allocate it to reverse the order, + // as it only performs one division per iteration instead of two. + #[cfg(feature = "alloc")] + { + self.to_base_le(base) + .collect::>() + .into_iter() + .rev() } - - assert!(base > 1); - OwnedVecIterator { - vec: self.to_base_le(base).collect(), + #[cfg(not(feature = "alloc"))] + { + SpigotBig::new(*self, base) } } @@ -196,6 +187,15 @@ struct SpigotLittle { limbs: [u64; LIMBS], } +impl SpigotLittle { + #[inline] + #[track_caller] + fn new(limbs: [u64; LIMBS], base: u64) -> Self { + assert!(base > 1); + Self { base, limbs } + } +} + impl Iterator for SpigotLittle { type Item = u64; @@ -220,6 +220,90 @@ impl Iterator for SpigotLittle { } } +/// Implementation of `to_base_be` when `alloc` feature is disabled. +/// +/// This is generally slower than simply reversing the result of `to_base_le` +/// as it performs two divisions per iteration instead of one. +#[cfg(not(feature = "alloc"))] +struct SpigotBig { + base: u64, + n: Uint, + power: Uint, + done: bool, +} + +#[cfg(not(feature = "alloc"))] +impl SpigotBig { + #[inline] + #[track_caller] + fn new(n: Uint, base: u64) -> Self { + assert!(base > 1); + + Self { + n, + base, + power: Self::highest_power(n, base), + done: false, + } + } + + /// Returns the largest power of `base` that fits in `n`. + #[inline] + fn highest_power(n: Uint, base: u64) -> Uint { + let mut power = Uint::ONE; + if base.is_power_of_two() { + loop { + match power.checked_shl(base.trailing_zeros() as _) { + Some(p) if p < n => power = p, + _ => break, + } + } + } else if let Ok(base) = Uint::try_from(base) { + loop { + match power.checked_mul(base) { + Some(p) if p < n => power = p, + _ => break, + } + } + } + power + } +} + +#[cfg(not(feature = "alloc"))] +impl Iterator for SpigotBig { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + if self.done { + return None; + } + + let digit; + if self.power == 1 { + digit = self.n; + self.done = true; + } else if self.base.is_power_of_two() { + digit = self.n >> self.power.trailing_zeros(); + self.n &= self.power - Uint::ONE; + + self.power >>= self.base.trailing_zeros(); + } else { + (digit, self.n) = self.n.div_rem(self.power); + self.power /= Uint::from(self.base); + } + + match u64::try_from(digit) { + Ok(digit) => Some(digit), + Err(e) => debug_unreachable!("digit {digit}: {e}"), + } + } +} + +#[cfg(not(feature = "alloc"))] +impl core::iter::FusedIterator for SpigotBig {} + #[cfg(test)] #[allow(clippy::unreadable_literal)] #[allow(clippy::zero_prefixed_literal)] @@ -331,4 +415,36 @@ mod tests { Err(BaseConvertError::Overflow) ); } + + #[test] + fn test_roundtrip() { + fn test(n: Uint, base: u64) { + assert_eq!( + n.to_base_be(base).collect::>(), + n.to_base_le(base) + .collect::>() + .into_iter() + .rev() + .collect::>(), + ); + + let digits = n.to_base_le(base); + let n2 = Uint::::from_base_le(base, digits).unwrap(); + assert_eq!(n, n2); + + let digits = n.to_base_be(base); + let n2 = Uint::::from_base_be(base, digits).unwrap(); + assert_eq!(n, n2); + } + + let single = |x: u64| x..=x; + for base in [2..=129, single(1 << 31), single(1 << 32), single(1 << 33)] + .into_iter() + .flatten() + { + test(Uint::<64, 1>::from(123456789), base); + test(Uint::<128, 2>::from(123456789), base); + test(N, base); + } + } } diff --git a/src/fmt.rs b/src/fmt.rs index 568129fc..a70ce127 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,5 +1,4 @@ #![allow(clippy::missing_inline_in_public_items)] // allow format functions -#![cfg(feature = "alloc")] use crate::Uint; use core::{ @@ -9,102 +8,83 @@ use core::{ mod base { pub(super) trait Base { + /// The base. + const BASE: u64; + /// The prefix for the base. + const PREFIX: &'static str; + /// Highest power of the base that fits in a `u64`. - const MAX: u64; + const MAX: u64 = crate::utils::max_pow_u64(Self::BASE); /// Number of characters written using `MAX` as the base in /// `to_base_be`. - /// - /// This is `MAX.log(base)`. + // TODO(MSRV-1.67): = `Self::MAX.ilog(Self::BASE)` const WIDTH: usize; - /// The prefix for the base. - const PREFIX: &'static str; } pub(super) struct Binary; impl Base for Binary { - const MAX: u64 = 1 << 63; - const WIDTH: usize = 63; + const BASE: u64 = 2; const PREFIX: &'static str = "0b"; + const WIDTH: usize = 63; } pub(super) struct Octal; impl Base for Octal { - const MAX: u64 = 1 << 63; - const WIDTH: usize = 21; + const BASE: u64 = 8; const PREFIX: &'static str = "0o"; + const WIDTH: usize = 21; } pub(super) struct Decimal; impl Base for Decimal { - const MAX: u64 = 10_000_000_000_000_000_000; - const WIDTH: usize = 19; + const BASE: u64 = 10; const PREFIX: &'static str = ""; + const WIDTH: usize = 19; } pub(super) struct Hexadecimal; impl Base for Hexadecimal { - const MAX: u64 = 1 << 60; - const WIDTH: usize = 15; + const BASE: u64 = 16; const PREFIX: &'static str = "0x"; + const WIDTH: usize = 15; } } use base::Base; -macro_rules! write_digits { - ($self:expr, $f:expr; $base:ty, $base_char:literal) => { - if LIMBS == 0 || $self.is_zero() { - return $f.pad_integral(true, <$base>::PREFIX, "0"); +macro_rules! impl_fmt { + ($tr:path; $base:ty, $base_char:literal) => { + impl $tr for Uint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Use `BITS` for all bases since `generic_const_exprs` is not yet stable. + let mut buffer = DisplayBuffer::::new(); + let mut first = true; + for spigot in self.to_base_be(<$base>::MAX) { + write!( + buffer, + concat!("{:0width$", $base_char, "}"), + spigot, + width = if first { 0 } else { <$base>::WIDTH }, + ) + .unwrap(); + first = false; + } + f.pad_integral(true, <$base>::PREFIX, buffer.as_str()) + } } - // Use `BITS` for all bases since `generic_const_exprs` is not yet stable. - let mut buffer = DisplayBuffer::::new(); - for (i, spigot) in $self.to_base_be(<$base>::MAX).enumerate() { - write!( - buffer, - concat!("{:0width$", $base_char, "}"), - spigot, - width = if i == 0 { 0 } else { <$base>::WIDTH }, - ) - .unwrap(); - } - return $f.pad_integral(true, <$base>::PREFIX, buffer.as_str()); }; } -impl fmt::Display for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_digits!(self, f; base::Decimal, ""); - } -} - impl fmt::Debug for Uint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self, f) } } -impl fmt::Binary for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_digits!(self, f; base::Binary, "b"); - } -} - -impl fmt::Octal for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_digits!(self, f; base::Octal, "o"); - } -} - -impl fmt::LowerHex for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_digits!(self, f; base::Hexadecimal, "x"); - } -} - -impl fmt::UpperHex for Uint { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write_digits!(self, f; base::Hexadecimal, "X"); - } -} +impl_fmt!(fmt::Display; base::Decimal, ""); +impl_fmt!(fmt::Binary; base::Binary, "b"); +impl_fmt!(fmt::Octal; base::Octal, "o"); +impl_fmt!(fmt::LowerHex; base::Hexadecimal, "x"); +impl_fmt!(fmt::UpperHex; base::Hexadecimal, "X"); struct DisplayBuffer { buf: [MaybeUninit; SIZE], diff --git a/src/utils.rs b/src/utils.rs index 23e512da..7ad69694 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -33,6 +33,26 @@ pub(crate) fn trim_end_vec(vec: &mut Vec, value: &T) { vec.truncate(last_idx(vec, value)); } +/// Returns the highest power of `n` that fits in `u64`. +#[inline] +pub(crate) const fn max_pow_u64(n: u64) -> u64 { + match n { + 2 | 8 => 1 << 63, + 10 => 10_000_000_000_000_000_000, + 16 => 1 << 60, + _ => max_pow_u64_impl(n), + } +} + +#[inline] +const fn max_pow_u64_impl(n: u64) -> u64 { + let mut max = n; + while let Some(next) = max.checked_mul(n) { + max = next; + } + max +} + // Branch prediction hints. #[cfg(feature = "nightly")] pub(crate) use core::intrinsics::{likely, unlikely}; @@ -68,4 +88,17 @@ mod tests { assert_eq!(trim_end_vec(vec![0, 1, 0, 0, 0], &0), &[0, 1]); assert_eq!(trim_end_vec(vec![0, 1, 0, 1, 0], &0), &[0, 1, 0, 1]); } + + #[test] + fn test_max_pow_u64() { + for (n, expected) in [ + (2, 1 << 63), + (8, 1 << 63), + (10, 10_000_000_000_000_000_000), + (16, 1 << 60), + ] { + assert_eq!(max_pow_u64(n), expected); + assert_eq!(max_pow_u64_impl(n), expected); + } + } }