diff --git a/der/src/asn1/set_of.rs b/der/src/asn1/set_of.rs index 14198b8da..b9314ad77 100644 --- a/der/src/asn1/set_of.rs +++ b/der/src/asn1/set_of.rs @@ -2,12 +2,12 @@ use crate::{ arrayvec, ord::iter_cmp, ArrayVec, Decodable, DecodeValue, Decoder, DerOrd, Encodable, - EncodeValue, Encoder, ErrorKind, FixedTag, Header, Length, Result, Tag, ValueOrd, + EncodeValue, Encoder, Error, ErrorKind, FixedTag, Header, Length, Result, Tag, ValueOrd, }; use core::cmp::Ordering; #[cfg(feature = "alloc")] -use {crate::Error, alloc::vec::Vec, core::slice}; +use {alloc::vec::Vec, core::slice}; /// ASN.1 `SET OF` backed by an array. /// @@ -18,14 +18,14 @@ use {crate::Error, alloc::vec::Vec, core::slice}; #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct SetOf where - T: Clone + DerOrd, + T: DerOrd, { inner: ArrayVec, } impl SetOf where - T: Clone + DerOrd, + T: DerOrd, { /// Create a new [`SetOf`]. pub fn new() -> Self { @@ -74,7 +74,7 @@ where impl Default for SetOf where - T: Clone + DerOrd, + T: DerOrd, { fn default() -> Self { Self::new() @@ -83,7 +83,7 @@ where impl<'a, T, const N: usize> DecodeValue<'a> for SetOf where - T: Clone + Decodable<'a> + DerOrd, + T: Decodable<'a> + DerOrd, { fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result { let end_pos = (decoder.position() + header.length)?; @@ -103,7 +103,7 @@ where impl<'a, T, const N: usize> EncodeValue for SetOf where - T: 'a + Clone + Decodable<'a> + Encodable + DerOrd, + T: 'a + Decodable<'a> + Encodable + DerOrd, { fn value_len(&self) -> Result { self.iter() @@ -121,14 +121,33 @@ where impl<'a, T, const N: usize> FixedTag for SetOf where - T: Clone + Decodable<'a> + DerOrd, + T: Decodable<'a> + DerOrd, { const TAG: Tag = Tag::Set; } +impl TryFrom<[T; N]> for SetOf +where + T: DerOrd, +{ + type Error = Error; + + fn try_from(mut arr: [T; N]) -> Result> { + der_sort(&mut arr)?; + + let mut result = SetOf::new(); + + for elem in arr { + result.add(elem)?; + } + + Ok(result) + } +} + impl ValueOrd for SetOf where - T: Clone + DerOrd, + T: DerOrd, { fn value_cmp(&self, other: &Self) -> Result { iter_cmp(self.iter(), other.iter()) @@ -161,7 +180,7 @@ impl<'a, T> ExactSizeIterator for SetOfIter<'a, T> {} #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub struct SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { inner: Vec, } @@ -170,7 +189,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { /// Create a new [`SetOfVec`]. pub fn new() -> Self { @@ -195,6 +214,11 @@ where Ok(()) } + /// Borrow the elements of this [`SetOfVec`] as a slice. + pub fn as_slice(&self) -> &[T] { + self.inner.as_slice() + } + /// Get the nth element from this [`SetOfVec`]. pub fn get(&self, index: usize) -> Option<&T> { self.inner.get(index) @@ -225,10 +249,10 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl AsRef<[T]> for SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { fn as_ref(&self) -> &[T] { - &self.inner + self.as_slice() } } @@ -236,7 +260,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl<'a, T> DecodeValue<'a> for SetOfVec where - T: Clone + Decodable<'a> + DerOrd, + T: Decodable<'a> + DerOrd, { fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result { let end_pos = (decoder.position() + header.length)?; @@ -258,7 +282,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl<'a, T> EncodeValue for SetOfVec where - T: 'a + Clone + Decodable<'a> + Encodable + DerOrd, + T: 'a + Decodable<'a> + Encodable + DerOrd, { fn value_len(&self) -> Result { self.iter() @@ -278,7 +302,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl FixedTag for SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { const TAG: Tag = Tag::Set; } @@ -287,7 +311,7 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl From> for Vec where - T: Clone + DerOrd, + T: DerOrd, { fn from(set: SetOfVec) -> Vec { set.into_vec() @@ -298,46 +322,84 @@ where #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl TryFrom> for SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { type Error = Error; fn try_from(mut vec: Vec) -> Result> { - if vec.len() > 1 { - // Use `Ordering::Less` as a placeholder in the event of comparison failure - vec.sort_by(|a, b| a.der_cmp(b).unwrap_or(Ordering::Less)); - - // Perform a pass over the elements to ensure they're sorted - for i in 0..(vec.len() - 1) { - match vec.get(i..(i + 2)) { - Some([a, b]) => match a.der_cmp(b) { - Ok(Ordering::Less) | Ok(Ordering::Equal) => (), - _ => return Err(ErrorKind::SetOrdering.into()), - }, - _ => return Err(ErrorKind::SetOrdering.into()), - } - } - } - + // TODO(tarcieri): use `[T]::sort_by` here? + der_sort(vec.as_mut_slice())?; Ok(SetOfVec { inner: vec }) } } +#[cfg(feature = "alloc")] +#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] +impl TryFrom<[T; N]> for SetOfVec +where + T: DerOrd, +{ + type Error = Error; + + fn try_from(arr: [T; N]) -> Result> { + Vec::from(arr).try_into() + } +} + #[cfg(feature = "alloc")] #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] impl ValueOrd for SetOfVec where - T: Clone + DerOrd, + T: DerOrd, { fn value_cmp(&self, other: &Self) -> Result { iter_cmp(self.iter(), other.iter()) } } -#[cfg(test)] +/// Sort a mut slice according to its [`DerOrd`], returning any errors which +/// might occur during the comparison. +/// +/// The algorithm is insertion sort, which should perform well when the input +/// is mostly sorted to begin with. +/// +/// This function is used rather than Rust's built-in `[T]::sort_by` in order +/// to support heapless `no_std` targets as well as to enable bubbling up +/// sorting errors. +fn der_sort(slice: &mut [T]) -> Result<()> { + for i in 1..=slice.len() { + let mut j = i - 1; + + while j > 0 && slice[j - 1].der_cmp(&slice[j])? == Ordering::Greater { + slice.swap(j - 1, j); + j -= 1; + } + } + + Ok(()) +} + +#[cfg(all(test, feature = "alloc"))] mod tests { - #[cfg(feature = "alloc")] - use super::SetOfVec; + use super::{SetOf, SetOfVec}; + use alloc::vec::Vec; + + #[test] + fn setof_tryfrom_array() { + let arr = [3u16, 2, 1, 65535, 0]; + let set = SetOf::try_from(arr).unwrap(); + assert_eq!( + set.iter().cloned().collect::>(), + &[0, 1, 2, 3, 65535] + ); + } + + #[test] + fn setofvec_tryfrom_array() { + let arr = [3u16, 2, 1, 65535, 0]; + let set = SetOfVec::try_from(arr).unwrap(); + assert_eq!(set.as_ref(), &[0, 1, 2, 3, 65535]); + } #[cfg(feature = "alloc")] #[test] diff --git a/der/tests/set_of.rs b/der/tests/set_of.rs index d182e2153..177b49bd3 100644 --- a/der/tests/set_of.rs +++ b/der/tests/set_of.rs @@ -1,41 +1,61 @@ //! `SetOf` tests. -#![cfg(all(feature = "derive", feature = "oid"))] - -use core::cmp::Ordering; -use der::{ - asn1::{Any, ObjectIdentifier, SetOf}, - Decodable, Result, Sequence, ValueOrd, +#[cfg(feature = "alloc")] +use { + der::{asn1::SetOfVec, DerOrd}, + proptest::{prelude::*, string::*}, }; -use hex_literal::hex; - -/// Attribute type/value pairs as defined in [RFC 5280 Section 4.1.2.4]. -/// -/// [RFC 5280 Section 4.1.2.4]: https://tools.ietf.org/html/rfc5280#section-4.1.2.4 -#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Sequence)] -pub struct AttributeTypeAndValue<'a> { - /// OID describing the type of the attribute - pub oid: ObjectIdentifier, - - /// Value of the attribute - pub value: Any<'a>, -} -impl ValueOrd for AttributeTypeAndValue<'_> { - fn value_cmp(&self, other: &Self) -> Result { - match self.oid.value_cmp(&other.oid)? { - Ordering::Equal => self.value.value_cmp(&other.value), - other => Ok(other), - } +#[cfg(feature = "alloc")] +proptest! { + #[test] + fn sort_equiv(bytes in bytes_regex(".{0,64}").unwrap()) { + let mut expected = bytes.clone(); + expected.sort_by(|a, b| a.der_cmp(b).unwrap()); + + let set = SetOfVec::try_from(bytes).unwrap(); + prop_assert_eq!(expected.as_slice(), set.as_slice()); } } -/// Test to ensure ordering is handled correctly. -#[test] -fn ordering_regression() { - let der_bytes = hex!("3139301906035504030C12546573742055736572393031353734333830301C060A0992268993F22C640101130E3437303031303030303134373333"); - let setof = SetOf::, 3>::from_der(&der_bytes).unwrap(); +#[cfg(all(feature = "derive", feature = "oid"))] +mod attr_regression { + #![cfg(all(feature = "derive", feature = "oid"))] + + use core::cmp::Ordering; + use der::{ + asn1::{Any, ObjectIdentifier, SetOf}, + Decodable, Result, Sequence, ValueOrd, + }; + use hex_literal::hex; + + /// Attribute type/value pairs as defined in [RFC 5280 Section 4.1.2.4]. + /// + /// [RFC 5280 Section 4.1.2.4]: https://tools.ietf.org/html/rfc5280#section-4.1.2.4 + #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Sequence)] + pub struct AttributeTypeAndValue<'a> { + /// OID describing the type of the attribute + pub oid: ObjectIdentifier, - let attr1 = setof.get(0).unwrap(); - assert_eq!(ObjectIdentifier::new("2.5.4.3"), attr1.oid); + /// Value of the attribute + pub value: Any<'a>, + } + + impl ValueOrd for AttributeTypeAndValue<'_> { + fn value_cmp(&self, other: &Self) -> Result { + match self.oid.value_cmp(&other.oid)? { + Ordering::Equal => self.value.value_cmp(&other.value), + other => Ok(other), + } + } + } + + /// Test to ensure ordering is handled correctly. + #[test] + fn ordering_regression() { + let der_bytes = hex!("3139301906035504030C12546573742055736572393031353734333830301C060A0992268993F22C640101130E3437303031303030303134373333"); + let set = SetOf::, 3>::from_der(&der_bytes).unwrap(); + let attr1 = set.get(0).unwrap(); + assert_eq!(ObjectIdentifier::new("2.5.4.3"), attr1.oid); + } }