From 92d2a8211581610191ee489372231aa644760c1a Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 11 Feb 2025 20:09:33 -0500 Subject: [PATCH 1/5] Make ML-DSA signature decoding follow the spec --- ml-dsa/src/hint.rs | 5 ++-- ml-dsa/src/lib.rs | 62 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/ml-dsa/src/hint.rs b/ml-dsa/src/hint.rs index df9f321e..83370e1d 100644 --- a/ml-dsa/src/hint.rs +++ b/ml-dsa/src/hint.rs @@ -33,7 +33,7 @@ fn use_hint(h: bool, r: Elem) -> Elem { } } -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Hint

(pub Array, P::K>) where P: SignatureParams; @@ -116,7 +116,7 @@ where } fn monotonic(a: &[usize]) -> bool { - a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] < *x) + a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] <= *x) } pub fn bit_unpack(y: &EncodedHint

) -> Option { @@ -138,6 +138,7 @@ where let indices = &indices[start..end]; if !Self::monotonic(indices) { + println!("indices not monotonic: {:?}", indices); return None; } diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 3efe91cf..8ed841e4 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -1,4 +1,4 @@ -#![no_std] +// XXX #![no_std] #![doc = include_str!("../README.md")] #![doc( html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg", @@ -89,7 +89,7 @@ pub use crate::util::B32; pub use signature::Error; /// An ML-DSA signature -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Signature { c_tilde: Array, z: Vector, @@ -899,4 +899,62 @@ mod test { sign_verify_round_trip_test::(); sign_verify_round_trip_test::(); } + + fn many_round_trip_test

() + where + P: MlDsaParams, + { + use rand::Rng; + + const ITERATIONS: usize = 1000; + + let mut rng = rand::thread_rng(); + let mut seed = B32::default(); + + for _i in 0..ITERATIONS { + let seed_data: &mut [u8] = seed.as_mut(); + rng.fill(seed_data); + + let kp = P::key_gen_internal(&seed); + let sk = kp.signing_key; + let vk = kp.verifying_key; + + let M = b"Hello world"; + let rnd = Array([0u8; 32]); + let sig = sk.sign_internal(&[M], &rnd); + + let sig_enc = sig.encode(); + let sig_dec = Signature::

::decode(&sig_enc).unwrap(); + + assert_eq!(sig_dec, sig); + assert!(vk.verify_internal(&[M], &sig)); + } + } + + #[test] + fn many_round_trip() { + many_round_trip_test::(); + many_round_trip_test::(); + many_round_trip_test::(); + } + + #[test] + fn encode_decode_fail() { + use signature::Signer; + + const SEED: [u8; 32] = [ + 197, 185, 159, 59, 216, 233, 208, 40, 244, 4, 182, 73, 109, 244, 205, 113, 116, 55, + 206, 145, 214, 205, 247, 130, 41, 113, 93, 14, 140, 194, 191, 232, + ]; + const MESSAGE: &[u8] = b"There seems to be a round tripping issue somewhere in here"; + + let mut seed = B32::default(); + seed.0.copy_from_slice(&SEED); + + let seed = SEED.into(); + let kp = MlDsa65::key_gen_internal(&seed); + let sig = kp.signing_key().sign(MESSAGE); + let sig_enc = sig.encode(); + let _sig = Signature::::decode(&sig_enc).unwrap(); + } } From c75ac84fcb39920e8c7cf8c7a25c025461ab3bb4 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 11 Feb 2025 20:12:36 -0500 Subject: [PATCH 2/5] Remove extra test --- ml-dsa/src/lib.rs | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 8ed841e4..bd15aca0 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -937,24 +937,4 @@ mod test { many_round_trip_test::(); many_round_trip_test::(); } - - #[test] - fn encode_decode_fail() { - use signature::Signer; - - const SEED: [u8; 32] = [ - 197, 185, 159, 59, 216, 233, 208, 40, 244, 4, 182, 73, 109, 244, 205, 113, 116, 55, - 206, 145, 214, 205, 247, 130, 41, 113, 93, 14, 140, 194, 191, 232, - ]; - const MESSAGE: &[u8] = b"There seems to be a round tripping issue somewhere in here"; - - let mut seed = B32::default(); - seed.0.copy_from_slice(&SEED); - - let seed = SEED.into(); - let kp = MlDsa65::key_gen_internal(&seed); - let sig = kp.signing_key().sign(MESSAGE); - let sig_enc = sig.encode(); - let _sig = Signature::::decode(&sig_enc).unwrap(); - } } From f1e4e305680c7bc0c22ee7f663d2ac3285325794 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 12 Feb 2025 10:07:32 -0500 Subject: [PATCH 3/5] Cleanup --- ml-dsa/src/hint.rs | 3 +-- ml-dsa/src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ml-dsa/src/hint.rs b/ml-dsa/src/hint.rs index 83370e1d..c68399d0 100644 --- a/ml-dsa/src/hint.rs +++ b/ml-dsa/src/hint.rs @@ -26,7 +26,7 @@ fn use_hint(h: bool, r: Elem) -> Elem { Elem::new((r1.0 + m - 1) % m) } else if h { // We use the Elem encoding even for signed integers. Since r0 is computed - // mod+- 2*gamma2, it is guaranteed to be in (gamma2, gamma2]. + // mod+- 2*gamma2, it is guaranteed to be in (-gamma2, gamma2]. unreachable!(); } else { r1 @@ -138,7 +138,6 @@ where let indices = &indices[start..end]; if !Self::monotonic(indices) { - println!("indices not monotonic: {:?}", indices); return None; } diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index bd15aca0..124ccbb7 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -1,4 +1,4 @@ -// XXX #![no_std] +#![no_std] #![doc = include_str!("../README.md")] #![doc( html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo.svg", From 1b794e1b900bfb9e7793d9eab219ddbfda09e642 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 12 Feb 2025 14:55:53 -0500 Subject: [PATCH 4/5] Allow for the full range of decompose() outputs --- ml-dsa/src/algebra.rs | 41 +++++++++++++++++++++++++++++++++++++++++ ml-dsa/src/hint.rs | 4 ++-- ml-dsa/src/lib.rs | 2 +- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/ml-dsa/src/algebra.rs b/ml-dsa/src/algebra.rs index b89bd283..ad7caea9 100644 --- a/ml-dsa/src/algebra.rs +++ b/ml-dsa/src/algebra.rs @@ -210,3 +210,44 @@ impl AlgebraExt for Vector { ) } } + +#[cfg(test)] +mod test { + use super::*; + + use crate::{MlDsa65, ParameterSet}; + + type TwoGamma2 = ::TwoGamma2; + const TWO_GAMMA_2: u32 = TwoGamma2::U32; + + #[test] + fn mod_plus_minus() { + for x in 0..BaseField::Q { + let x = Elem::new(x); + let x0 = x.mod_plus_minus::(); + + // Outputs from mod+- should be in the half-open interval (-gamma2, gamma2] + let positive_bound = x0.0 <= TWO_GAMMA_2 / 2; + let negative_bound = x0.0 > BaseField::Q - TWO_GAMMA_2 / 2; + assert!(positive_bound || negative_bound); + } + } + + #[test] + fn decompose() { + for x in 0..BaseField::Q { + let x = Elem::new(x); + let (x1, x0) = x.decompose::(); + + // The low-order output from decompose() is a mod+- output, optionally minus one. So + // they should be in the closed interval [-gamma2, gamma2]. + let positive_bound = x0.0 <= TWO_GAMMA_2 / 2; + let negative_bound = x0.0 >= BaseField::Q - TWO_GAMMA_2 / 2; + assert!(positive_bound || negative_bound); + + // The low-order and high-order values + let xx = (TWO_GAMMA_2 * x1.0 + x0.0) % BaseField::Q; + assert_eq!(xx, x.0); + } + } +} diff --git a/ml-dsa/src/hint.rs b/ml-dsa/src/hint.rs index c68399d0..ced3cf34 100644 --- a/ml-dsa/src/hint.rs +++ b/ml-dsa/src/hint.rs @@ -22,11 +22,11 @@ fn use_hint(h: bool, r: Elem) -> Elem { let gamma2 = TwoGamma2::U32 / 2; if h && r0.0 <= gamma2 { Elem::new((r1.0 + 1) % m) - } else if h && r0.0 > BaseField::Q - gamma2 { + } else if h && r0.0 >= BaseField::Q - gamma2 { Elem::new((r1.0 + m - 1) % m) } else if h { // We use the Elem encoding even for signed integers. Since r0 is computed - // mod+- 2*gamma2, it is guaranteed to be in (-gamma2, gamma2]. + // mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2]. unreachable!(); } else { r1 diff --git a/ml-dsa/src/lib.rs b/ml-dsa/src/lib.rs index 124ccbb7..96d856bf 100644 --- a/ml-dsa/src/lib.rs +++ b/ml-dsa/src/lib.rs @@ -927,7 +927,7 @@ mod test { let sig_dec = Signature::

::decode(&sig_enc).unwrap(); assert_eq!(sig_dec, sig); - assert!(vk.verify_internal(&[M], &sig)); + assert!(vk.verify_internal(&[M], &sig_dec)); } } From 17e9cab3531e4ea795ff160c6b25895388ff3004 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 12 Feb 2025 15:17:05 -0500 Subject: [PATCH 5/5] Cleanup --- ml-dsa/src/algebra.rs | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/ml-dsa/src/algebra.rs b/ml-dsa/src/algebra.rs index ad7caea9..85de202e 100644 --- a/ml-dsa/src/algebra.rs +++ b/ml-dsa/src/algebra.rs @@ -217,36 +217,45 @@ mod test { use crate::{MlDsa65, ParameterSet}; - type TwoGamma2 = ::TwoGamma2; - const TWO_GAMMA_2: u32 = TwoGamma2::U32; + type Mod = ::TwoGamma2; + const MOD: u32 = Mod::U32; + const MOD_ELEM: Elem = Elem::new(MOD); #[test] fn mod_plus_minus() { - for x in 0..BaseField::Q { + for x in 0..MOD { + // BaseField::Q { let x = Elem::new(x); - let x0 = x.mod_plus_minus::(); + let x0 = x.mod_plus_minus::(); // Outputs from mod+- should be in the half-open interval (-gamma2, gamma2] - let positive_bound = x0.0 <= TWO_GAMMA_2 / 2; - let negative_bound = x0.0 > BaseField::Q - TWO_GAMMA_2 / 2; + let positive_bound = x0.0 <= MOD / 2; + let negative_bound = x0.0 > BaseField::Q - MOD / 2; assert!(positive_bound || negative_bound); + + // The output should be equivalent to the input, mod 2 * gamma2. We add 2 * gamma2 + // before comparing so that both values are "positive", avoiding interactions between + // the mod-Q and mod-M operations. + let xn = x + MOD_ELEM; + let x0n = x0 + MOD_ELEM; + assert_eq!(xn.0 % MOD, x0n.0 % MOD); } } #[test] fn decompose() { - for x in 0..BaseField::Q { + for x in 0..MOD { let x = Elem::new(x); - let (x1, x0) = x.decompose::(); + let (x1, x0) = x.decompose::(); // The low-order output from decompose() is a mod+- output, optionally minus one. So // they should be in the closed interval [-gamma2, gamma2]. - let positive_bound = x0.0 <= TWO_GAMMA_2 / 2; - let negative_bound = x0.0 >= BaseField::Q - TWO_GAMMA_2 / 2; + let positive_bound = x0.0 <= MOD / 2; + let negative_bound = x0.0 >= BaseField::Q - MOD / 2; assert!(positive_bound || negative_bound); - // The low-order and high-order values - let xx = (TWO_GAMMA_2 * x1.0 + x0.0) % BaseField::Q; + // The low-order and high-order outputs should combine to form the input. + let xx = (MOD * x1.0 + x0.0) % BaseField::Q; assert_eq!(xx, x.0); } }