diff --git a/benches/uint.rs b/benches/uint.rs index 0c10425a1..9c5df6c38 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -18,6 +18,18 @@ fn bench_division(c: &mut Criterion) { ) }); + group.bench_function("div/rem_vartime, U256/U128, full size", |b| { + b.iter_batched( + || { + let x = U256::random(&mut OsRng); + let y: U256 = (U128::MAX, U128::ZERO).into(); + (x, NonZero::new(y).unwrap()) + }, + |(x, y)| black_box(x.div_rem_vartime(&y)), + BatchSize::SmallInput, + ) + }); + group.bench_function("rem, U256/U128, full size", |b| { b.iter_batched( || { @@ -35,8 +47,7 @@ fn bench_division(c: &mut Criterion) { b.iter_batched( || { let x = U256::random(&mut OsRng); - let y_half = U128::random(&mut OsRng); - let y: U256 = (y_half, U128::ZERO).into(); + let y: U256 = (U128::MAX, U128::ZERO).into(); (x, NonZero::new(y).unwrap()) }, |(x, y)| black_box(x.rem_vartime(&y)), @@ -44,6 +55,18 @@ fn bench_division(c: &mut Criterion) { ) }); + group.bench_function("rem_wide_vartime, U256", |b| { + b.iter_batched( + || { + let (x_lo, x_hi) = (U256::random(&mut OsRng), U256::random(&mut OsRng)); + let y: U256 = (U128::MAX, U128::ZERO).into(); + (x_lo, x_hi, NonZero::new(y).unwrap()) + }, + |(x_lo, x_hi, y)| black_box(Uint::rem_wide_vartime((x_lo, x_hi), &y)), + BatchSize::SmallInput, + ) + }); + group.bench_function("div/rem, U256/Limb, full size", |b| { b.iter_batched( || { diff --git a/src/uint/div.rs b/src/uint/div.rs index 7eeb5ee95..32e838943 100644 --- a/src/uint/div.rs +++ b/src/uint/div.rs @@ -1,7 +1,12 @@ //! [`Uint`] division operations. -use super::div_limb::{div_rem_limb_with_reciprocal, rem_limb_with_reciprocal, Reciprocal}; -use crate::{CheckedDiv, ConstChoice, DivRemLimb, Limb, NonZero, RemLimb, Uint, Word, Wrapping}; +use super::div_limb::{ + div2by1, div_rem_limb_with_reciprocal, rem_limb_with_reciprocal, rem_limb_with_reciprocal_wide, + Reciprocal, +}; +use crate::{ + CheckedDiv, ConstChoice, DivRemLimb, Limb, NonZero, RemLimb, Uint, WideWord, Word, Wrapping, +}; use core::ops::{Div, DivAssign, Rem, RemAssign}; use subtle::CtOption; @@ -69,29 +74,147 @@ impl Uint { /// /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. - #[allow(trivial_numeric_casts)] - pub const fn div_rem_vartime(&self, rhs: &NonZero) -> (Self, Self) { - let mb = rhs.0.bits_vartime(); - let mut bd = Self::BITS - mb; - let mut rem = *self; - let mut quo = Self::ZERO; - // If there is overflow, it means `mb == 0`, so `rhs == 0`. - let mut c = rhs.0.wrapping_shl_vartime(bd); + pub const fn div_rem_vartime( + &self, + rhs: &NonZero>, + ) -> (Self, Uint) { + // Based on Section 4.3.1, of The Art of Computer Programming, Volume 2, by Donald E. Knuth. + // Further explanation at https://janmr.com/blog/2014/04/basic-multiple-precision-long-division/ + + let dbits = rhs.0.bits_vartime(); + let yc = ((dbits + Limb::BITS - 1) / Limb::BITS) as usize; + + // Short circuit for small or extra large divisors + match yc { + 1 => { + // If the divisor is a single limb, use limb division + let (q, r) = div_rem_limb_with_reciprocal( + self, + &Reciprocal::new(rhs.0.limbs[0].to_nz().expect("zero divisor")), + ); + return (q, Uint::from_word(r.0)); + } + yc if yc > LIMBS => { + // Divisor is greater than dividend. Return zero and the dividend as the + // quotient and remainder + return (Uint::ZERO, self.resize()); + } + _ => {} + }; + + let lshift = (Limb::BITS - (dbits % Limb::BITS)) % Limb::BITS; + let rshift = if lshift == 0 { 0 } else { Limb::BITS - lshift }; + let mut x = self.to_limbs(); + let mut x_hi = Limb::ZERO; + let mut xi = LIMBS - 1; + let mut y = rhs.0.to_limbs(); + let mut i; + let mut carry; + + if lshift != 0 { + // Shift divisor such that it has no leading zeros + // This means that div2by1 requires no extra shifts, and ensures that the high word >= b/2 + i = 0; + carry = Limb::ZERO; + while i < yc { + (y[i], carry) = (Limb((y[i].0 << lshift) | carry.0), Limb(y[i].0 >> rshift)); + i += 1; + } + + // Shift the dividend to match + i = 0; + carry = Limb::ZERO; + while i < LIMBS { + (x[i], carry) = (Limb((x[i].0 << lshift) | carry.0), Limb(x[i].0 >> rshift)); + i += 1; + } + x_hi = carry; + } + + let reciprocal = Reciprocal::new(y[yc - 1].to_nz().expect("zero divisor")); loop { - let (r, borrow) = rem.sbb(&c, Limb::ZERO); - let choice = ConstChoice::from_word_mask(borrow.0); - rem = Self::select(&r, &rem, choice); - quo = Self::select(&quo.bitor(&Self::ONE), &quo, choice); - if bd == 0 { + // Divide high dividend words by the high divisor word to estimate the quotient word + let (mut quo, mut rem) = div2by1(x_hi.0, x[xi].0, &reciprocal); + + i = 0; + while i < 2 { + let qy = (quo as WideWord) * (y[yc - 2].0 as WideWord); + let rx = ((rem as WideWord) << Word::BITS) | (x[xi - 1].0 as WideWord); + // Constant-time check for q*y[-2] < r*x[-1], based on ConstChoice::from_word_lt + let diff = ConstChoice::from_word_lsb( + ((((!rx) & qy) | (((!rx) | qy) & (rx.wrapping_sub(qy)))) + >> (WideWord::BITS - 1)) as Word, + ); + quo = diff.select_word(quo, quo.saturating_sub(1)); + rem = diff.select_word(rem, rem.saturating_add(y[yc - 1].0)); + i += 1; + } + + // Subtract q*divisor from the dividend + carry = Limb::ZERO; + let mut borrow = Limb::ZERO; + let mut tmp; + i = 0; + while i < yc { + (tmp, carry) = Limb::ZERO.mac(y[i], Limb(quo), carry); + (x[xi + i + 1 - yc], borrow) = x[xi + i + 1 - yc].sbb(tmp, borrow); + i += 1; + } + (_, borrow) = x_hi.sbb(carry, borrow); + + // If the subtraction borrowed, then decrement q and add back the divisor + // The probability of this being needed is very low, about 2/(Limb::MAX+1) + let ct_borrow = ConstChoice::from_word_mask(borrow.0); + carry = Limb::ZERO; + i = 0; + while i < yc { + (x[xi + i + 1 - yc], carry) = + x[xi + i + 1 - yc].adc(Limb::select(Limb::ZERO, y[i], ct_borrow), carry); + i += 1; + } + quo = ct_borrow.select_word(quo, quo.saturating_sub(1)); + + // Store the quotient within dividend and set x_hi to the current highest word + x_hi = x[xi]; + x[xi] = Limb(quo); + + if xi == yc - 1 { break; } - bd -= 1; - c = c.shr1(); - quo = quo.shl1(); + xi -= 1; } - (quo, rem) + // Copy the remainder to divisor + i = 0; + while i < yc - 1 { + y[i] = x[i]; + i += 1; + } + y[yc - 1] = x_hi; + + // Unshift the remainder from the earlier adjustment + if lshift != 0 { + i = yc; + carry = Limb::ZERO; + while i > 0 { + i -= 1; + (y[i], carry) = (Limb((y[i].0 >> lshift) | carry.0), Limb(y[i].0 << rshift)); + } + } + + // Shift the quotient to the low limbs within dividend + i = 0; + while i < LIMBS { + if i <= (LIMBS - yc) { + x[i] = x[i + yc - 1]; + } else { + x[i] = Limb::ZERO; + } + i += 1; + } + + (Uint::new(x), Uint::new(y)) } /// Computes `self` % `rhs`, returns the remainder. @@ -104,22 +227,7 @@ impl Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. pub const fn rem_vartime(&self, rhs: &NonZero) -> Self { - let mb = rhs.0.bits_vartime(); - let mut bd = Self::BITS - mb; - let mut rem = *self; - let mut c = rhs.0.wrapping_shl_vartime(bd); - - loop { - let (r, borrow) = rem.sbb(&c, Limb::ZERO); - rem = Self::select(&r, &rem, ConstChoice::from_word_mask(borrow.0)); - if bd == 0 { - break; - } - bd -= 1; - c = c.shr1(); - } - - rem + self.div_rem_vartime(rhs).1 } /// Computes `self` % `rhs`, returns the remainder. @@ -129,32 +237,128 @@ impl Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. pub const fn rem_wide_vartime(lower_upper: (Self, Self), rhs: &NonZero) -> Self { - let mb = rhs.0.bits_vartime(); + let dbits = rhs.0.bits_vartime(); + let yc = ((dbits + Limb::BITS - 1) / Limb::BITS) as usize; + + // If the divisor is a single limb, use limb division + if yc == 1 { + let r = rem_limb_with_reciprocal_wide( + (&lower_upper.0, &lower_upper.1), + &Reciprocal::new(rhs.0.limbs[0].to_nz().expect("zero divisor")), + ); + return Uint::from_word(r.0); + } - // The number of bits to consider is two sets of limbs * BITS - mb (modulus bitcount) - let mut bd = (2 * Self::BITS) - mb; + let lshift = (Limb::BITS - (dbits % Limb::BITS)) % Limb::BITS; + let rshift = if lshift == 0 { 0 } else { Limb::BITS - lshift }; + let mut x = lower_upper.1.to_limbs(); // high limbs + let mut x_hi = Limb::ZERO; + let mut xi = LIMBS - 1; + let mut y = rhs.0.to_limbs(); + let mut extra_limbs = LIMBS; + let mut i; + let mut carry; + + if lshift != 0 { + // Shift divisor such that it has no leading zeros + // This ensures that the high word >= b/2, and means that div2by1 requires no extra shifts + i = 0; + carry = Limb::ZERO; + while i < yc { + (y[i], carry) = (Limb((y[i].0 << lshift) | carry.0), Limb(y[i].0 >> rshift)); + i += 1; + } - // The wide integer to reduce, split into two halves - let (mut lower, mut upper) = lower_upper; + // Shift the dividend to match + i = 0; + carry = Limb(lower_upper.0.limbs[LIMBS - 1].0 >> rshift); + while i < LIMBS { + (x[i], carry) = (Limb((x[i].0 << lshift) | carry.0), Limb(x[i].0 >> rshift)); + i += 1; + } + x_hi = carry; + } - // Factor of the modulus, split into two halves - let mut c = Self::overflowing_shl_vartime_wide((rhs.0, Uint::ZERO), bd) - .expect("shift within range"); + let reciprocal = Reciprocal::new(y[yc - 1].to_nz().expect("zero divisor")); loop { - let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO); - let (upper_sub, borrow) = upper.sbb(&c.1, borrow); + // Divide high dividend words by the high divisor word to estimate the quotient word + let (mut quo, mut rem) = div2by1(x_hi.0, x[xi].0, &reciprocal); + + i = 0; + while i < 2 { + let qy = (quo as WideWord) * (y[yc - 2].0 as WideWord); + let rx = ((rem as WideWord) << Word::BITS) | (x[xi - 1].0 as WideWord); + // Constant-time check for q*y[-2] < r*x[-1], based on ConstChoice::from_word_lt + let diff = ConstChoice::from_word_lsb( + ((((!rx) & qy) | (((!rx) | qy) & (rx.wrapping_sub(qy)))) + >> (WideWord::BITS - 1)) as Word, + ); + quo = diff.select_word(quo, quo.saturating_sub(1)); + rem = diff.select_word(rem, rem.saturating_add(y[yc - 1].0)); + i += 1; + } - lower = Self::select(&lower_sub, &lower, ConstChoice::from_word_mask(borrow.0)); - upper = Self::select(&upper_sub, &upper, ConstChoice::from_word_mask(borrow.0)); - if bd == 0 { - break; + // Subtract q*divisor from the dividend + carry = Limb::ZERO; + let mut borrow = Limb::ZERO; + let mut tmp; + i = 0; + while i < yc { + (tmp, carry) = Limb::ZERO.mac(y[i], Limb(quo), carry); + (x[xi + i + 1 - yc], borrow) = x[xi + i + 1 - yc].sbb(tmp, borrow); + i += 1; + } + (_, borrow) = x_hi.sbb(carry, borrow); + + // If the subtraction borrowed, then add back the divisor + // The probability of this being needed is very low, about 2/(Limb::MAX+1) + let ct_borrow = ConstChoice::from_word_mask(borrow.0); + carry = Limb::ZERO; + i = 0; + while i < yc { + (x[xi + i + 1 - yc], carry) = + x[xi + i + 1 - yc].adc(Limb::select(Limb::ZERO, y[i], ct_borrow), carry); + i += 1; + } + + // Set x_hi to the current highest word + x_hi = x[xi]; + + // If we have lower limbs remaining, shift the divisor words one word left + if extra_limbs > 0 { + extra_limbs -= 1; + i = LIMBS - 1; + while i > 0 { + x[i] = x[i - 1]; + i -= 1; + } + x[0] = lower_upper.0.limbs[extra_limbs]; + if lshift != 0 { + x[0].0 <<= lshift; + if extra_limbs > 0 { + x[0].0 |= lower_upper.0.limbs[extra_limbs - 1].0 >> rshift; + } + } + } else { + if xi == yc - 1 { + break; + } + xi -= 1; + } + } + + // Unshift the remainder from the earlier adjustment + if lshift != 0 { + i = yc; + carry = Limb::ZERO; + while i > 0 { + i -= 1; + (x[i], carry) = (Limb((x[i].0 >> lshift) | carry.0), Limb(x[i].0 << rshift)); } - bd -= 1; - c = Self::overflowing_shr_vartime_wide(c, 1).expect("shift within range"); } - lower + Uint::new(x) } /// Computes `self` % 2^k. Faster than reduce since its a power of 2. @@ -200,17 +404,15 @@ impl Uint { /// There’s no way wrapping could ever happen. /// This function exists, so that all operations are accounted for in the wrapping operations. pub const fn wrapping_div(&self, rhs: &NonZero) -> Self { - let (q, _) = self.div_rem(rhs); - q + self.div_rem(rhs).0 } /// Wrapped division is just normal division i.e. `self` / `rhs` /// /// There’s no way wrapping could ever happen. /// This function exists, so that all operations are accounted for in the wrapping operations. - pub const fn wrapping_div_vartime(&self, rhs: &NonZero) -> Self { - let (q, _) = self.div_rem_vartime(rhs); - q + pub const fn wrapping_div_vartime(&self, rhs: &NonZero>) -> Self { + self.div_rem_vartime(rhs).0 } /// Perform checked division, returning a [`CtOption`] which `is_some` @@ -771,6 +973,11 @@ mod tests { &NonZero::new(U256::from(7u8)).unwrap(), ); assert_eq!(r, U256::from(3u8)); + let r = U256::rem_wide_vartime( + (U256::from(10u8), U256::ZERO), + &NonZero::new(U256::MAX).unwrap(), + ); + assert_eq!(r, U256::from(10u8)); } #[test] diff --git a/src/uint/div_limb.rs b/src/uint/div_limb.rs index f622d8999..28f876517 100644 --- a/src/uint/div_limb.rs +++ b/src/uint/div_limb.rs @@ -254,6 +254,32 @@ pub(crate) const fn rem_limb_with_reciprocal( Limb(r >> reciprocal.shift) } +/// Divides the wide `u` by the divisor encoded in the `reciprocal`, and returns the remainder. +#[inline(always)] +pub(crate) const fn rem_limb_with_reciprocal_wide( + lo_hi: (&Uint, &Uint), + reciprocal: &Reciprocal, +) -> Limb { + let (lo_shifted, carry) = lo_hi.0.shl_limb(reciprocal.shift); + let (mut hi_shifted, xhi) = lo_hi.1.shl_limb(reciprocal.shift); + hi_shifted.limbs[0].0 |= carry.0; + let mut r = xhi.0; + + let mut j = L; + while j > 0 { + j -= 1; + let (_, rj) = div2by1(r, hi_shifted.as_limbs()[j].0, reciprocal); + r = rj; + } + j = L; + while j > 0 { + j -= 1; + let (_, rj) = div2by1(r, lo_shifted.as_limbs()[j].0, reciprocal); + r = rj; + } + Limb(r >> reciprocal.shift) +} + #[cfg(test)] mod tests { use super::{div2by1, Reciprocal};