diff --git a/der/src/asn1/bit_string.rs b/der/src/asn1/bit_string.rs index 146541bf6..31ac34c95 100644 --- a/der/src/asn1/bit_string.rs +++ b/der/src/asn1/bit_string.rs @@ -1,6 +1,6 @@ //! ASN.1 `BIT STRING` support. -pub mod fixed_len_bit_string; +pub mod allowed_len_bit_string; use crate::{ BytesRef, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, FixedTag, Header, Length, Reader, diff --git a/der/src/asn1/bit_string/fixed_len_bit_string.rs b/der/src/asn1/bit_string/allowed_len_bit_string.rs similarity index 93% rename from der/src/asn1/bit_string/fixed_len_bit_string.rs rename to der/src/asn1/bit_string/allowed_len_bit_string.rs index 3d12d0d6b..ee861046e 100644 --- a/der/src/asn1/bit_string/fixed_len_bit_string.rs +++ b/der/src/asn1/bit_string/allowed_len_bit_string.rs @@ -27,7 +27,7 @@ use crate::{Error, ErrorKind, Tag}; /// flag4: bool, /// } /// ``` -pub trait FixedLenBitString { +pub trait AllowedLenBitString { /// Implementer must specify how many bits are allowed const ALLOWED_LEN_RANGE: RangeInclusive; @@ -35,7 +35,7 @@ pub trait FixedLenBitString { fn check_bit_len(bit_len: u16) -> Result<(), Error> { let allowed_len_range = Self::ALLOWED_LEN_RANGE; - // forces allowed range to eg. 3..=4 + // forces allowed range to e.g. 3..=4 if !allowed_len_range.contains(&bit_len) { Err(ErrorKind::Length { tag: Tag::BitString, diff --git a/der/src/lib.rs b/der/src/lib.rs index bb1e7f020..ef18ead82 100644 --- a/der/src/lib.rs +++ b/der/src/lib.rs @@ -365,7 +365,7 @@ mod document; mod str_owned; pub use crate::{ - asn1::bit_string::fixed_len_bit_string::FixedLenBitString, + asn1::bit_string::allowed_len_bit_string::AllowedLenBitString, asn1::{AnyRef, Choice, Sequence}, datetime::DateTime, decode::{Decode, DecodeOwned, DecodeValue}, diff --git a/der/tests/derive.rs b/der/tests/derive.rs index 0624ba21f..98dd25313 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -699,7 +699,9 @@ mod bitstring { assert_eq!(reencoded, BITSTRING_EXAMPLE); } - /// this BitString will allow only 3..=4 bits + /// this BitString will allow only 3..=4 bits in Decode + /// + /// but will always Encode 4 bits #[derive(BitString)] pub struct MyBitString3or4 { pub bit_0: bool, @@ -765,8 +767,8 @@ mod bitstring { .to_der() .unwrap(); - // 3 bits used, 5 unused - assert_eq!(encoded_3_zeros, hex!("03 02 05 00")); + // 4 bits used, 4 unused + assert_eq!(encoded_3_zeros, hex!("03 02 04 00")); } #[test] @@ -780,8 +782,8 @@ mod bitstring { .to_der() .unwrap(); - // 3 bits used, 5 unused - assert_eq!(encoded_3_zeros, hex!("03 02 05 E0")); + // 4 bits used, 4 unused + assert_eq!(encoded_3_zeros, hex!("03 02 04 E0")); } #[test] @@ -813,6 +815,107 @@ mod bitstring { // 4 bits used, 4 unused assert_eq!(encoded_4_zeros, hex!("03 02 04 10")); } + + /// ```asn1 + /// PasswordFlags ::= BIT STRING { + /// case-sensitive (0), + /// local (1), + /// change-disabled (2), + /// unblock-disabled (3), + /// initialized (4), + /// needs-padding (5), + /// unblockingPassword (6), + /// soPassword (7), + /// disable-allowed (8), + /// integrity-protected (9), + /// confidentiality-protected (10), + /// exchangeRefData (11), + /// resetRetryCounter1 (12), + /// resetRetryCounter2 (13), + /// context-dependent (14), + /// multiStepProtocol (15) + /// } + /// ``` + #[derive(Clone, Debug, Eq, PartialEq, BitString)] + pub struct PasswordFlags { + /// case-sensitive (0) + pub case_sensitive: bool, + + /// local (1) + pub local: bool, + + /// change-disabled (2) + pub change_disabled: bool, + + /// unblock-disabled (3) + pub unblock_disabled: bool, + + /// initialized (4) + pub initialized: bool, + + /// needs-padding (5) + pub needs_padding: bool, + + /// unblockingPassword (6) + pub unblocking_password: bool, + + /// soPassword (7) + pub so_password: bool, + + /// disable-allowed (8) + pub disable_allowed: bool, + + /// integrity-protected (9) + pub integrity_protected: bool, + + /// confidentiality-protected (10) + pub confidentiality_protected: bool, + + /// exchangeRefData (11) + pub exchange_ref_data: bool, + + /// Second edition 2016-05-15 + /// resetRetryCounter1 (12) + #[asn1(optional = "true")] + pub reset_retry_counter1: bool, + + /// resetRetryCounter2 (13) + #[asn1(optional = "true")] + pub reset_retry_counter2: bool, + + /// context-dependent (14) + #[asn1(optional = "true")] + pub context_dependent: bool, + + /// multiStepProtocol (15) + #[asn1(optional = "true")] + pub multi_step_protocol: bool, + + /// fake_bit_for_testing (16) + #[asn1(optional = "true")] + pub fake_bit_for_testing: bool, + } + + const PASS_FLAGS_EXAMPLE_IN: &[u8] = &hex!("03 03 04 FF FF"); + const PASS_FLAGS_EXAMPLE_OUT: &[u8] = &hex!("03 04 07 FF F0 00"); + + #[test] + fn decode_short_bitstring_2_bytes() { + let pass_flags = PasswordFlags::from_der(PASS_FLAGS_EXAMPLE_IN).unwrap(); + + // case-sensitive (0) + assert!(pass_flags.case_sensitive); + + // exchangeRefData (11) + assert!(pass_flags.exchange_ref_data); + + // resetRetryCounter1 (12) + assert!(!pass_flags.reset_retry_counter1); + + let reencoded = pass_flags.to_der().unwrap(); + + assert_eq!(reencoded, PASS_FLAGS_EXAMPLE_OUT); + } } mod infer_default { //! When another crate might define a PartialEq for another type, the use of diff --git a/der_derive/src/bitstring.rs b/der_derive/src/bitstring.rs index 5cdcb20b3..508069fa3 100644 --- a/der_derive/src/bitstring.rs +++ b/der_derive/src/bitstring.rs @@ -35,12 +35,26 @@ impl DeriveBitString { let type_attrs = TypeAttrs::parse(&input.attrs)?; - let fields = data + let fields: Vec<_> = data .fields .iter() .map(|field| BitStringField::new(field, &type_attrs)) .collect::>()?; + let mut started_optionals = false; + for field in &fields { + if !field.attrs.optional { + if started_optionals { + abort!( + input.ident, + "derive `BitString` only supports trailing optional fields one after another", + ) + } + } else { + started_optionals = true; + } + } + Ok(Self { ident: input.ident, generics: input.generics.clone(), @@ -75,14 +89,18 @@ impl DeriveBitString { let mut min_expected_fields: u16 = 0; let mut max_expected_fields: u16 = 0; + let mut started_optionals = false; for field in &self.fields { max_expected_fields += 1; - if !field.attrs.optional { + if field.attrs.optional { + started_optionals = true; + } + if !started_optionals { min_expected_fields += 1; } } - let min_expected_bytes = (min_expected_fields + 7) / 8; + let max_expected_bytes = (max_expected_fields + 7) / 8; for (i, field) in self.fields.iter().enumerate().rev() { let field_name = &field.ident; @@ -115,7 +133,7 @@ impl DeriveBitString { impl ::der::FixedTag for #ident #ty_generics #where_clause { const TAG: der::Tag = ::der::Tag::BitString; } - impl ::der::FixedLenBitString for #ident #ty_generics #where_clause { + impl ::der::AllowedLenBitString for #ident #ty_generics #where_clause { const ALLOWED_LEN_RANGE: ::core::ops::RangeInclusive = #min_expected_fields..=#max_expected_fields; } @@ -127,7 +145,7 @@ impl DeriveBitString { header: ::der::Header, ) -> ::core::result::Result { use ::der::{Decode as _, DecodeValue as _, Reader as _}; - use ::der::FixedLenBitString as _; + use ::der::AllowedLenBitString as _; let bs = ::der::asn1::BitStringRef::decode_value(reader, header)?; @@ -147,33 +165,15 @@ impl DeriveBitString { impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause { fn value_len(&self) -> der::Result { - Ok(der::Length::new(#min_expected_bytes + 1)) + Ok(der::Length::new(#max_expected_bytes + 1)) } fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> { use ::der::Encode as _; - use der::FixedLenBitString as _; + use ::der::AllowedLenBitString as _; let arr = [#(#encode_bytes),*]; - - let min_bits = { - let max_bits = *Self::ALLOWED_LEN_RANGE.end(); - let last_byte_bits = (max_bits % 8) as u8; - let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?; - - let mut min_bits = *Self::ALLOWED_LEN_RANGE.start(); - - // find last lit bit - for bit_index in Self::ALLOWED_LEN_RANGE.rev() { - if bs.get(bit_index as usize).unwrap_or_default() { - min_bits = bit_index + 1; - break; - } - } - min_bits - }; - - let last_byte_bits = (min_bits % 8) as u8; + let last_byte_bits = (#max_expected_fields % 8) as u8; let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?; bs.encode_value(writer) }