diff --git a/ml-dsa/src/crypto.rs b/ml-dsa/src/crypto.rs index 8e938e6e..cc41cec7 100644 --- a/ml-dsa/src/crypto.rs +++ b/ml-dsa/src/crypto.rs @@ -18,8 +18,11 @@ impl Default for ShakeState { } impl ShakeState { - pub fn pre_digest(digest: Shake) -> Self { - Self::Absorbing(digest) + pub fn updatable(&mut self) -> &mut Shake { + match self { + Self::Absorbing(sponge) => sponge, + Self::Squeezing(_) => unreachable!(), + } } pub fn absorb(mut self, input: &[u8]) -> Self { diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 8a4c529c..75b2a07b 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -51,11 +51,10 @@ use hybrid_array::{ U75, U80, U88, Unsigned, }, }; -use signature::digest::Update; use signature::{DigestSigner, DigestVerifier, MultipartSigner, MultipartVerifier, Signer}; #[cfg(feature = "rand_core")] -use signature::RandomizedDigestSigner; +use signature::{RandomizedDigestSigner, RandomizedMultipartSigner, RandomizedSigner}; #[cfg(feature = "rand_core")] use rand_core::{CryptoRng, TryCryptoRng}; @@ -171,17 +170,46 @@ where const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = P::ALGORITHM_IDENTIFIER; } -// This method takes a slice of slices so that we can accommodate the varying calculations (direct -// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without -// having to allocate memory for components. -fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 { - let mut h = H::default().absorb(tr); +struct MuBuilder(H); - for m in Mp.iter().copied().flatten() { - h = h.absorb(m); +impl MuBuilder { + fn new(tr: &[u8], ctx: &[u8]) -> Self { + let mut h = H::default(); + h = h.absorb(tr); + h = h.absorb(&[0]); + h = h.absorb(&[Truncate::truncate(ctx.len())]); + h = h.absorb(ctx); + + Self(h) + } + + fn internal(tr: &[u8], Mp: &[&[u8]]) -> B64 { + let mut h = H::default().absorb(tr); + + for m in Mp { + h = h.absorb(m); + } + + h.squeeze_new() + } + + fn message(mut self, M: &[&[u8]]) -> B64 { + for m in M { + self.0 = self.0.absorb(m); + } + + self.0.squeeze_new() + } + + fn finish(mut self) -> B64 { + self.0.squeeze_new() } +} - h.squeeze_new() +impl AsMut for MuBuilder { + fn as_mut(&mut self) -> &mut Shake256 { + self.0.updatable() + } } /// An ML-DSA key pair @@ -388,18 +416,7 @@ impl SigningKey

{ where P: MlDsaParams, { - self.raw_sign_internal(&[Mp], rnd) - } - - fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature

- where - P: MlDsaParams, - { - // Compute the message representative - // XXX(RLB): This line incorporates some of the logic from ML-DSA.sign to avoid computing - // the concatenated M'. - // XXX(RLB) Should the API represent this as an input? - let mu = message_representative(&self.tr, Mp); + let mu = MuBuilder::internal(&self.tr, Mp); self.raw_sign_mu(&mu, rnd) } @@ -469,6 +486,16 @@ impl SigningKey

{ M: &[u8], ctx: &[u8], rng: &mut R, + ) -> Result, Error> { + self.raw_sign_randomized(&[M], ctx, rng) + } + + #[cfg(feature = "rand_core")] + fn raw_sign_randomized( + &self, + Mp: &[&[u8]], + ctx: &[u8], + rng: &mut R, ) -> Result, Error> { if ctx.len() > 255 { return Err(Error::new()); @@ -477,8 +504,8 @@ impl SigningKey

{ let mut rnd = B32::default(); rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?; - let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M]; - Ok(self.sign_internal(Mp, &rnd)) + let mu = MuBuilder::new(&self.tr, ctx).message(Mp); + Ok(self.raw_sign_mu(&mu, &rnd)) } /// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ. @@ -517,14 +544,13 @@ impl SigningKey

{ self.raw_sign_mu(mu, &rnd) } - fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result, Error> { + fn raw_sign_deterministic(&self, Mp: &[&[u8]], ctx: &[u8]) -> Result, Error> { if ctx.len() > 255 { return Err(Error::new()); } - let rnd = B32::default(); - let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M]; - Ok(self.raw_sign_internal(Mp, &rnd)) + let mu = MuBuilder::new(&self.tr, ctx).message(Mp); + Ok(self.sign_mu_deterministic(&mu)) } /// Encode the key in a fixed-size byte array. @@ -608,9 +634,9 @@ impl DigestSigner> for SigningKey

{ &self, f: F, ) -> Result, Error> { - let mut digest = Shake256::default().chain(self.tr).chain([0, 0]); - f(&mut digest)?; - let mu = H::pre_digest(digest).squeeze_new(); + let mut mu = MuBuilder::new(&self.tr, &[]); + f(mu.as_mut())?; + let mu = mu.finish(); Ok(self.sign_mu_deterministic(&mu)) } @@ -640,13 +666,27 @@ impl signature::Keypair for SigningKey

{ /// context string. If you would like to include a context string, use the /// [`SigningKey::sign_randomized`] method. #[cfg(feature = "rand_core")] -impl signature::RandomizedSigner> for SigningKey

{ +impl RandomizedSigner> for SigningKey

{ fn try_sign_with_rng( &self, rng: &mut R, msg: &[u8], ) -> Result, Error> { - self.sign_randomized(msg, &[], rng) + self.try_multipart_sign_with_rng(rng, &[msg]) + } +} + +/// The `RandomizedSigner` implementation for `SigningKey` only supports signing with an empty +/// context string. If you would like to include a context string, use the +/// [`SigningKey::sign_randomized`] method. +#[cfg(feature = "rand_core")] +impl RandomizedMultipartSigner> for SigningKey

{ + fn try_multipart_sign_with_rng( + &self, + rng: &mut R, + msg: &[&[u8]], + ) -> Result, Error> { + self.raw_sign_randomized(msg, &[], rng) } } @@ -663,9 +703,9 @@ impl RandomizedDigestSigner> for SigningK rng: &mut R, f: F, ) -> Result, Error> { - let mut digest = Shake256::default().chain(self.tr).chain([0, 0]); - f(&mut digest)?; - let mu = H::pre_digest(digest).squeeze_new(); + let mut mu = MuBuilder::new(&self.tr, &[]); + f(mu.as_mut())?; + let mu = mu.finish(); self.sign_mu_randomized(&mu, rng) } @@ -736,19 +776,11 @@ impl VerifyingKey

{ /// include the domain separator that distinguishes between the normal and pre-hashed cases, /// and it does not separate the context string from the rest of the message. // Algorithm 8 ML-DSA.Verify_internal - pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature

) -> bool - where - P: MlDsaParams, - { - self.raw_verify_internal(&[Mp], sigma) - } - - fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature

) -> bool + pub fn verify_internal(&self, M: &[u8], sigma: &Signature

) -> bool where P: MlDsaParams, { - // Compute the message representative - let mu = message_representative(&self.tr, Mp); + let mu = MuBuilder::internal(&self.tr, &[M]); self.raw_verify_mu(&mu, sigma) } @@ -793,8 +825,8 @@ impl VerifyingKey

{ return false; } - let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M]; - self.raw_verify_internal(Mp, sigma) + let mu = MuBuilder::new(&self.tr, ctx).message(M); + self.verify_mu(&mu, sigma) } fn encode_internal(rho: &B32, t1: &Vector) -> EncodedVerifyingKey

{ @@ -837,9 +869,9 @@ impl DigestVerifier> for VerifyingKey

f: F, signature: &Signature

, ) -> Result<(), Error> { - let mut digest = Shake256::default().chain(self.tr).chain([0, 0]); - f(&mut digest)?; - let mu = H::pre_digest(digest).squeeze_new(); + let mut mu = MuBuilder::new(&self.tr, &[]); + f(mu.as_mut())?; + let mu = mu.finish(); self.raw_verify_mu(&mu, signature) .then_some(()) @@ -1060,6 +1092,7 @@ where mod test { use super::*; use crate::param::*; + use signature::digest::Update; #[test] fn output_sizes() { @@ -1142,7 +1175,7 @@ mod test { let rnd = Array([0u8; 32]); let sig = sk.sign_internal(&[M], &rnd); - assert!(vk.verify_internal(&[M], &sig)); + assert!(vk.verify_internal(M, &sig)); } #[test] @@ -1179,7 +1212,7 @@ mod test { let sig_dec = Signature::

::decode(&sig_enc).unwrap(); assert_eq!(sig_dec, sig); - assert!(vk.verify_internal(&[M], &sig_dec)); + assert!(vk.verify_internal(M, &sig_dec)); } } @@ -1202,7 +1235,7 @@ mod test { let M = b"Hello world"; let rnd = Array([0u8; 32]); - let mu = message_representative(&sk.tr, &[&[M]]); + let mu = MuBuilder::internal(&sk.tr, &[M]); let sig = sk.raw_sign_mu(&mu, &rnd); assert!(vk.raw_verify_mu(&mu, &sig)); @@ -1224,10 +1257,10 @@ mod test { let M = b"Hello world"; let rnd = Array([0u8; 32]); - let mu = message_representative(&sk.tr, &[&[M]]); + let mu = MuBuilder::internal(&sk.tr, &[M]); let sig = sk.raw_sign_mu(&mu, &rnd); - assert!(vk.verify_internal(&[M], &sig)); + assert!(vk.verify_internal(M, &sig)); } sign_mu_verify_internal::(); sign_mu_verify_internal::(); @@ -1246,7 +1279,7 @@ mod test { let M = b"Hello world"; let rnd = Array([0u8; 32]); - let mu = message_representative(&sk.tr, &[&[M]]); + let mu = MuBuilder::internal(&sk.tr, &[M]); let sig = sk.sign_internal(&[M], &rnd); assert!(vk.raw_verify_mu(&mu, &sig)); diff --git a/ml-dsa/tests/sig-ver.rs b/ml-dsa/tests/sig-ver.rs index 2c783e78..b5e7505d 100644 --- a/ml-dsa/tests/sig-ver.rs +++ b/ml-dsa/tests/sig-ver.rs @@ -35,7 +35,7 @@ fn verify(tg: &acvp::TestGroup, tc: &acvp::TestCase) { // Verify the signature if it successfully decoded let test_passed = sig - .map(|sig| vk.verify_internal(&[&tc.message], &sig)) + .map(|sig| vk.verify_internal(&tc.message, &sig)) .unwrap_or_default(); assert_eq!(test_passed, tc.test_passed); }