diff --git a/src/uint/modular/runtime_mod.rs b/src/uint/modular/runtime_mod.rs index 32b078599..4468d71ac 100644 --- a/src/uint/modular/runtime_mod.rs +++ b/src/uint/modular/runtime_mod.rs @@ -7,6 +7,8 @@ use super::{ Retrieve, }; +use subtle::CtOption; + /// Additions between residues with a modulus set at runtime mod runtime_add; /// Multiplicative inverses of residues with a modulus set at runtime @@ -20,7 +22,7 @@ mod runtime_pow; /// Subtractions between residues with a modulus set at runtime mod runtime_sub; -/// The parameters to efficiently go to and from the Montgomery form for a modulus provided at runtime. +/// The parameters to efficiently go to and from the Montgomery form for an odd modulus provided at runtime. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct DynResidueParams { // The constant modulus @@ -37,8 +39,8 @@ pub struct DynResidueParams { } impl DynResidueParams { - /// Instantiates a new set of `ResidueParams` representing the given `modulus`. - pub const fn new(modulus: &Uint) -> Self { + // Internal helper function to generate parameters; this lets us wrap the constructors more cleanly + const fn generate_params(modulus: &Uint) -> Self { let r = Uint::MAX.const_rem(modulus).0.wrapping_add(&Uint::ONE); let r2 = Uint::const_rem_wide(r.square_wide(), modulus).0; @@ -59,6 +61,28 @@ impl DynResidueParams { } } + /// Instantiates a new set of `ResidueParams` representing the given `modulus`, which _must_ be odd. + /// If `modulus` is not odd, this function will panic; use [`new_checked`][`DynResidueParams::new_checked`] if you want to be able to detect an invalid modulus. + pub const fn new(modulus: &Uint) -> Self { + // A valid modulus must be odd + if modulus.ct_is_odd().to_u8() == 0 { + panic!("modulus must be odd"); + } + + Self::generate_params(modulus) + } + + /// Instantiates a new set of `ResidueParams` representing the given `modulus` if it is odd. + /// Returns a `CtOption` that is `None` if the provided modulus is not odd; this is a safer version of [`new`][`DynResidueParams::new`], which can panic. + #[deprecated( + since = "0.5.3", + note = "This functionality will be moved to `new` in a future release." + )] + pub fn new_checked(modulus: &Uint) -> CtOption { + // A valid modulus must be odd, which we check in constant time + CtOption::new(Self::generate_params(modulus), modulus.ct_is_odd().into()) + } + /// Returns the modulus which was used to initialize these parameters. pub const fn modulus(&self) -> &Uint { &self.modulus @@ -194,3 +218,37 @@ impl zeroize::Zeroize for DynResidue { self.montgomery_form.zeroize() } } + +#[cfg(test)] +mod test { + use super::*; + use crate::nlimbs; + + const LIMBS: usize = nlimbs!(64); + + #[test] + #[allow(deprecated)] + // Test that a valid modulus yields `DynResidueParams` + fn test_valid_modulus() { + let valid_modulus = Uint::::from(3u8); + + DynResidueParams::::new_checked(&valid_modulus).unwrap(); + DynResidueParams::::new(&valid_modulus); + } + + #[test] + #[allow(deprecated)] + // Test that an invalid checked modulus does not yield `DynResidueParams` + fn test_invalid_checked_modulus() { + assert!(bool::from( + DynResidueParams::::new_checked(&Uint::from(2u8)).is_none() + )) + } + + #[test] + #[should_panic] + // Tets that an invalid modulus panics + fn test_invalid_modulus() { + DynResidueParams::::new(&Uint::from(2u8)); + } +}