diff --git a/src/base_convert.rs b/src/base_convert.rs index 3c21d162..8680acd4 100644 --- a/src/base_convert.rs +++ b/src/base_convert.rs @@ -125,9 +125,24 @@ impl Uint { return Err(BaseConvertError::InvalidBase(base)); } - let mut tail = digits.into_iter(); - match tail.next() { - Some(digit) => Self::from_base_le_recurse(digit, base, &mut tail), + // we compute this by dividing the number of bits in the Uint by the + // number of bits in each word, rounding up. + let max_digits: usize = BITS / (usize::BITS - base.leading_zeros()) as usize + 1; + + let mut iter = digits.into_iter(); + let mut digits = iter.by_ref().take(max_digits); + + match digits.next() { + Some(digit) => { + let res = Self::from_base_le_recurse::(digit, base, digits)?; + + // if the iter was not drained during the recursion process, + // then we have overflowed + if iter.next().is_some() { + return Err(BaseConvertError::Overflow); + } + Ok(res) + } None => Ok(Self::ZERO), } } @@ -141,7 +156,7 @@ impl Uint { fn from_base_le_recurse>( digit: u64, base: u64, - tail: &mut I, + mut tail: core::iter::Take<&mut I>, ) -> Result { if digit > base { return Err(BaseConvertError::InvalidDigit(digit, base)); @@ -176,11 +191,23 @@ impl Uint { return Err(BaseConvertError::InvalidBase(base)); } + // we compute this by dividing the number of bits in the Uint by the + // number of bits in each word, rounding up. + let max_digits: usize = BITS / (usize::BITS - base.leading_zeros()) as usize + 1; + + let mut iter = digits.into_iter(); + let digits = iter.by_ref().take(max_digits); + let mut result = Self::ZERO; for digit in digits { result.add_digit(digit, base)?; } + // If the iterator still contains digits, we have overflowed. + if iter.next().is_some() { + return Err(BaseConvertError::Overflow); + } + Ok(result) } } @@ -265,6 +292,19 @@ mod tests { .unwrap(), N ); + + assert_eq!( + Uint::<4, 1>::from_base_le(10, [6, 1]), + Err(BaseConvertError::Overflow), + ); + assert_eq!( + Uint::<4, 1>::from_base_le(10, [5, 1, 318]), + Err(BaseConvertError::Overflow), + ); + assert_eq!( + Uint::<4, 1>::from_base_le(10, [5, 1]), + Ok(Uint::<4, 1>::from(15)), + ); } #[test] @@ -308,6 +348,10 @@ mod tests { #[test] fn test_from_base_be_overflow() { + assert_eq!( + Uint::<1, 1>::from_base_be(10, std::iter::repeat(0)), + Err(BaseConvertError::Overflow) + ); assert_eq!( Uint::<0, 0>::from_base_be(10, [].into_iter()), Ok(Uint::<0, 0>::ZERO) diff --git a/src/string.rs b/src/string.rs index 73f858ef..425b6b56 100644 --- a/src/string.rs +++ b/src/string.rs @@ -255,4 +255,9 @@ mod tests { assert_eq!(format!("{n:#b}"), format!("{value:#066b}")); }); } + + #[test] + fn test_from_hex_extra_zeroes() { + Uint::<16, 1>::from_str("0x00000000000001").unwrap(); + } }