Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use self::{
add::{adc_n, cmp, sbb_n},
div::div,
gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, submul_nx1},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1},
ops::{adc, sbb},
shift::{shift_left_small, shift_right_small},
};
Expand Down
13 changes: 12 additions & 1 deletion src/algorithms/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,18 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
prod.high()
}

/// Computes `lhs += a * b` and returns the borrow.
/// Computes `lhs *= a` and returns the carry.
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
let mut carry = 0;
for lhs in &mut *lhs {
let product = u128::muladd(*lhs, a, carry);
*lhs = product.low();
carry = product.high();
}
carry
}

/// Computes `lhs += a * b` and returns the carry.
///
/// Requires `lhs.len() == a.len()`.
///
Expand Down
2 changes: 1 addition & 1 deletion src/aliases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub mod tests {
use super::*;

#[test]
fn instantiate_consts() {
const fn instantiate_consts() {
let _ = (U0::ZERO, U0::MAX, B0::ZERO);
let _ = (U1::ZERO, U1::MAX, B1::ZERO);
let _ = (U8::ZERO, U8::MAX, B8::ZERO);
Expand Down
116 changes: 61 additions & 55 deletions src/base_convert.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::Uint;
use crate::{
algorithms::{addmul_nx1, mul_nx1},
Uint,
};
use core::fmt;

/// Error for [`from_base_le`][Uint::from_base_le] and
Expand Down Expand Up @@ -85,29 +88,6 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
}
}

/// Adds a digit in base `base` to the number. This is used internally by
/// [`Uint::from_base_le`] and [`Uint::from_base_be`].
#[inline]
fn add_digit(&mut self, digit: u64, base: u64) -> Result<(), BaseConvertError> {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
// Multiply by base.
// OPT: keep track of non-zero limbs and mul the minimum.
let mut carry: u128 = u128::from(digit);
#[allow(clippy::cast_possible_truncation)]
for limb in &mut self.limbs {
carry += u128::from(*limb) * u128::from(base);
*limb = carry as u64;
carry >>= 64;
}
if carry > 0 || (LIMBS != 0 && self.limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow);
}

Ok(())
}

/// Constructs the [`Uint`] from digits in the base `base` in little-endian.
///
/// # Errors
Expand All @@ -124,36 +104,48 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
if base < 2 {
return Err(BaseConvertError::InvalidBase(base));
}

let mut tail = digits.into_iter();
match tail.next() {
Some(digit) => Self::from_base_le_recurse(digit, base, &mut tail),
None => Ok(Self::ZERO),
if BITS == 0 {
for digit in digits {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
if digit != 0 {
return Err(BaseConvertError::Overflow);
}
}
return Ok(Self::ZERO);
}
}

/// This is the recursive part of [`Uint::from_base_le`].
///
/// We drain the iterator via the recursive calls, and then perform the
/// same construction loop as [`Uint::from_base_be`] while exiting the
/// recursive callstack.
#[inline]
fn from_base_le_recurse<I: Iterator<Item = u64>>(
digit: u64,
base: u64,
tail: &mut I,
) -> Result<Self, BaseConvertError> {
if digit > base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
let mut iter = digits.into_iter();
let mut result = Self::ZERO;
let mut power = Self::from(1);
for digit in iter.by_ref() {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}

let mut acc = match tail.next() {
Some(digit) => Self::from_base_le_recurse::<I>(digit, base, tail)?,
None => Self::ZERO,
};
// Add digit to result
let overflow = addmul_nx1(&mut result.limbs, &power.limbs, digit);
if overflow != 0 || result.limbs[LIMBS - 1] > Self::MASK {
return Err(BaseConvertError::Overflow);
}

acc.add_digit(digit, base)?;
Ok(acc)
// Update power
let overflow = mul_nx1(&mut power.limbs, base);
if overflow != 0 || power.limbs[LIMBS - 1] > Self::MASK {
// Following digits must be zero
break;
}
}
for digit in iter {
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
if digit != 0 {
return Err(BaseConvertError::Overflow);
}
}
Ok(result)
}

/// Constructs the [`Uint`] from digits in the base `base` in big-endian.
Expand All @@ -178,7 +170,21 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {

let mut result = Self::ZERO;
for digit in digits {
result.add_digit(digit, base)?;
if digit >= base {
return Err(BaseConvertError::InvalidDigit(digit, base));
}
// Multiply by base.
// OPT: keep track of non-zero limbs and mul the minimum.
let mut carry: u128 = u128::from(digit);
#[allow(clippy::cast_possible_truncation)]
for limb in &mut result.limbs {
carry += u128::from(*limb) * u128::from(base);
*limb = carry as u64;
carry >>= 64;
}
if carry > 0 || (LIMBS != 0 && result.limbs[LIMBS - 1] > Self::MASK) {
return Err(BaseConvertError::Overflow);
}
}

Ok(result)
Expand Down Expand Up @@ -309,20 +315,20 @@ mod tests {
#[test]
fn test_from_base_be_overflow() {
assert_eq!(
Uint::<0, 0>::from_base_be(10, [].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::empty()),
Ok(Uint::<0, 0>::ZERO)
);
assert_eq!(
Uint::<0, 0>::from_base_be(10, [0].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::once(0)),
Ok(Uint::<0, 0>::ZERO)
);
assert_eq!(
Uint::<0, 0>::from_base_be(10, [1].into_iter()),
Uint::<0, 0>::from_base_be(10, std::iter::once(1)),
Err(BaseConvertError::Overflow)
);
assert_eq!(
Uint::<1, 1>::from_base_be(10, [1, 0, 0].into_iter()),
Err(BaseConvertError::Overflow)
)
);
}
}