Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion der/src/asn1/bit_string.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! ASN.1 `BIT STRING` support.

pub mod fixed_len_bit_string;
pub mod allowed_len_bit_string;

use crate::{
BytesRef, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, FixedTag, Header, Length, Reader,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use crate::{Error, ErrorKind, Tag};
/// flag4: bool,
/// }
/// ```
pub trait FixedLenBitString {
pub trait AllowedLenBitString {
/// Implementer must specify how many bits are allowed
const ALLOWED_LEN_RANGE: RangeInclusive<u16>;

/// Returns an error if the bitstring is not in expected length range
fn check_bit_len(bit_len: u16) -> Result<(), Error> {
let allowed_len_range = Self::ALLOWED_LEN_RANGE;

// forces allowed range to eg. 3..=4
// forces allowed range to e.g. 3..=4
if !allowed_len_range.contains(&bit_len) {
Err(ErrorKind::Length {
tag: Tag::BitString,
Expand Down
2 changes: 1 addition & 1 deletion der/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ mod document;
mod str_owned;

pub use crate::{
asn1::bit_string::fixed_len_bit_string::FixedLenBitString,
asn1::bit_string::allowed_len_bit_string::AllowedLenBitString,
asn1::{AnyRef, Choice, Sequence},
datetime::DateTime,
decode::{Decode, DecodeOwned, DecodeValue},
Expand Down
113 changes: 108 additions & 5 deletions der/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,9 @@ mod bitstring {
assert_eq!(reencoded, BITSTRING_EXAMPLE);
}

/// this BitString will allow only 3..=4 bits
/// this BitString will allow only 3..=4 bits in Decode
///
/// but will always Encode 4 bits
#[derive(BitString)]
pub struct MyBitString3or4 {
pub bit_0: bool,
Expand Down Expand Up @@ -765,8 +767,8 @@ mod bitstring {
.to_der()
.unwrap();

// 3 bits used, 5 unused
assert_eq!(encoded_3_zeros, hex!("03 02 05 00"));
// 4 bits used, 4 unused
assert_eq!(encoded_3_zeros, hex!("03 02 04 00"));
}

#[test]
Expand All @@ -780,8 +782,8 @@ mod bitstring {
.to_der()
.unwrap();

// 3 bits used, 5 unused
assert_eq!(encoded_3_zeros, hex!("03 02 05 E0"));
// 4 bits used, 4 unused
assert_eq!(encoded_3_zeros, hex!("03 02 04 E0"));
}

#[test]
Expand Down Expand Up @@ -813,6 +815,107 @@ mod bitstring {
// 4 bits used, 4 unused
assert_eq!(encoded_4_zeros, hex!("03 02 04 10"));
}

/// ```asn1
/// PasswordFlags ::= BIT STRING {
/// case-sensitive (0),
/// local (1),
/// change-disabled (2),
/// unblock-disabled (3),
/// initialized (4),
/// needs-padding (5),
/// unblockingPassword (6),
/// soPassword (7),
/// disable-allowed (8),
/// integrity-protected (9),
/// confidentiality-protected (10),
/// exchangeRefData (11),
/// resetRetryCounter1 (12),
/// resetRetryCounter2 (13),
/// context-dependent (14),
/// multiStepProtocol (15)
/// }
/// ```
#[derive(Clone, Debug, Eq, PartialEq, BitString)]
pub struct PasswordFlags {
/// case-sensitive (0)
pub case_sensitive: bool,

/// local (1)
pub local: bool,

/// change-disabled (2)
pub change_disabled: bool,

/// unblock-disabled (3)
pub unblock_disabled: bool,

/// initialized (4)
pub initialized: bool,

/// needs-padding (5)
pub needs_padding: bool,

/// unblockingPassword (6)
pub unblocking_password: bool,

/// soPassword (7)
pub so_password: bool,

/// disable-allowed (8)
pub disable_allowed: bool,

/// integrity-protected (9)
pub integrity_protected: bool,

/// confidentiality-protected (10)
pub confidentiality_protected: bool,

/// exchangeRefData (11)
pub exchange_ref_data: bool,

/// Second edition 2016-05-15
/// resetRetryCounter1 (12)
#[asn1(optional = "true")]
pub reset_retry_counter1: bool,

/// resetRetryCounter2 (13)
#[asn1(optional = "true")]
pub reset_retry_counter2: bool,

/// context-dependent (14)
#[asn1(optional = "true")]
pub context_dependent: bool,

/// multiStepProtocol (15)
#[asn1(optional = "true")]
pub multi_step_protocol: bool,

/// fake_bit_for_testing (16)
#[asn1(optional = "true")]
pub fake_bit_for_testing: bool,
}

const PASS_FLAGS_EXAMPLE_IN: &[u8] = &hex!("03 03 04 FF FF");
const PASS_FLAGS_EXAMPLE_OUT: &[u8] = &hex!("03 04 07 FF F0 00");

#[test]
fn decode_short_bitstring_2_bytes() {
let pass_flags = PasswordFlags::from_der(PASS_FLAGS_EXAMPLE_IN).unwrap();

// case-sensitive (0)
assert!(pass_flags.case_sensitive);

// exchangeRefData (11)
assert!(pass_flags.exchange_ref_data);

// resetRetryCounter1 (12)
assert!(!pass_flags.reset_retry_counter1);

let reencoded = pass_flags.to_der().unwrap();

assert_eq!(reencoded, PASS_FLAGS_EXAMPLE_OUT);
}
}
mod infer_default {
//! When another crate might define a PartialEq for another type, the use of
Expand Down
52 changes: 26 additions & 26 deletions der_derive/src/bitstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,26 @@ impl DeriveBitString {

let type_attrs = TypeAttrs::parse(&input.attrs)?;

let fields = data
let fields: Vec<_> = data
.fields
.iter()
.map(|field| BitStringField::new(field, &type_attrs))
.collect::<syn::Result<_>>()?;

let mut started_optionals = false;
for field in &fields {
if !field.attrs.optional {
if started_optionals {
abort!(
input.ident,
"derive `BitString` only supports trailing optional fields one after another",
)
}
} else {
started_optionals = true;
}
}

Ok(Self {
ident: input.ident,
generics: input.generics.clone(),
Expand Down Expand Up @@ -75,14 +89,18 @@ impl DeriveBitString {

let mut min_expected_fields: u16 = 0;
let mut max_expected_fields: u16 = 0;
let mut started_optionals = false;
for field in &self.fields {
max_expected_fields += 1;

if !field.attrs.optional {
if field.attrs.optional {
started_optionals = true;
}
if !started_optionals {
min_expected_fields += 1;
}
}
let min_expected_bytes = (min_expected_fields + 7) / 8;
let max_expected_bytes = (max_expected_fields + 7) / 8;

for (i, field) in self.fields.iter().enumerate().rev() {
let field_name = &field.ident;
Expand Down Expand Up @@ -115,7 +133,7 @@ impl DeriveBitString {
impl ::der::FixedTag for #ident #ty_generics #where_clause {
const TAG: der::Tag = ::der::Tag::BitString;
}
impl ::der::FixedLenBitString for #ident #ty_generics #where_clause {
impl ::der::AllowedLenBitString for #ident #ty_generics #where_clause {
const ALLOWED_LEN_RANGE: ::core::ops::RangeInclusive<u16> = #min_expected_fields..=#max_expected_fields;
}

Expand All @@ -127,7 +145,7 @@ impl DeriveBitString {
header: ::der::Header,
) -> ::core::result::Result<Self, ::der::Error> {
use ::der::{Decode as _, DecodeValue as _, Reader as _};
use ::der::FixedLenBitString as _;
use ::der::AllowedLenBitString as _;


let bs = ::der::asn1::BitStringRef::decode_value(reader, header)?;
Expand All @@ -147,33 +165,15 @@ impl DeriveBitString {

impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
fn value_len(&self) -> der::Result<der::Length> {
Ok(der::Length::new(#min_expected_bytes + 1))
Ok(der::Length::new(#max_expected_bytes + 1))
}

fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
use ::der::Encode as _;
use der::FixedLenBitString as _;
use ::der::AllowedLenBitString as _;

let arr = [#(#encode_bytes),*];

let min_bits = {
let max_bits = *Self::ALLOWED_LEN_RANGE.end();
let last_byte_bits = (max_bits % 8) as u8;
let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?;

let mut min_bits = *Self::ALLOWED_LEN_RANGE.start();

// find last lit bit
for bit_index in Self::ALLOWED_LEN_RANGE.rev() {
if bs.get(bit_index as usize).unwrap_or_default() {
min_bits = bit_index + 1;
break;
}
}
min_bits
};

let last_byte_bits = (min_bits % 8) as u8;
let last_byte_bits = (#max_expected_fields % 8) as u8;
let bs = ::der::asn1::BitStringRef::new(8 - last_byte_bits, &arr)?;
bs.encode_value(writer)
}
Expand Down