Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion der/src/header.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! ASN.1 DER headers.

use crate::{Decode, DerOrd, Encode, Error, ErrorKind, Length, Reader, Result, Tag, Writer};
use crate::{
Decode, DerOrd, Encode, EncodingRules, Error, ErrorKind, Length, Reader, Result, Tag, Writer,
};
use core::cmp::Ordering;

/// ASN.1 DER headers: tag + length component of TLV-encoded values
Expand Down Expand Up @@ -34,6 +36,7 @@ impl<'a> Decode<'a> for Header {
type Error = Error;

fn decode<R: Reader<'a>>(reader: &mut R) -> Result<Header> {
let is_constructed = Tag::peek_is_constructed(reader)?;
let tag = Tag::decode(reader)?;

let length = Length::decode(reader).map_err(|e| {
Expand All @@ -44,6 +47,11 @@ impl<'a> Decode<'a> for Header {
}
})?;

if length.is_indefinite() && !is_constructed {
debug_assert_eq!(reader.encoding_rules(), EncodingRules::Ber);
return Err(reader.error(ErrorKind::IndefiniteLength));
}

Ok(Self { tag, length })
}
}
Expand Down
109 changes: 72 additions & 37 deletions der/src/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,27 @@ use core::{
const INDEFINITE_LENGTH_OCTET: u8 = 0b10000000; // 0x80

/// ASN.1-encoded length.
#[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct Length(u32);
#[derive(Copy, Clone, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct Length {
/// Inner length as a `u32`. Note that the decoder and encoder also support a maximum length
/// of 32-bits.
inner: u32,

/// Flag bit which specifies whether the length was indeterminate when decoding ASN.1 BER.
///
/// This should always be false when working with DER.
indefinite: bool,
}

impl Length {
/// Length of `0`
pub const ZERO: Self = Self(0);
pub const ZERO: Self = Self::new(0);

/// Length of `1`
pub const ONE: Self = Self(1);
pub const ONE: Self = Self::new(1);

/// Maximum length (`u32::MAX`).
pub const MAX: Self = Self(u32::MAX);
pub const MAX: Self = Self::new(u32::MAX);

/// Maximum number of octets in a DER encoding of a [`Length`] using the
/// rules implemented by this crate.
Expand All @@ -38,8 +47,11 @@ impl Length {
/// Create a new [`Length`] for any value which fits inside of a [`u16`].
///
/// This function is const-safe and therefore useful for [`Length`] constants.
pub const fn new(value: u16) -> Self {
Self(value as u32)
pub const fn new(value: u32) -> Self {
Self {
inner: value,
indefinite: false,
}
}

/// Create a new [`Length`] for any value which fits inside the length type.
Expand All @@ -50,14 +62,18 @@ impl Length {
if len > (u32::MAX as usize) {
Err(Error::from_kind(ErrorKind::Overflow))
} else {
Ok(Length(len as u32))
Ok(Self::new(len as u32))
}
}

/// Is this length equal to zero?
pub const fn is_zero(self) -> bool {
let value = self.0;
value == 0
self.inner == 0
}

/// Was this length decoded from an indefinite length when decoding BER?
pub(crate) const fn is_indefinite(self) -> bool {
self.indefinite
}

/// Get the length of DER Tag-Length-Value (TLV) encoded data if `self`
Expand All @@ -68,12 +84,12 @@ impl Length {

/// Perform saturating addition of two lengths.
pub fn saturating_add(self, rhs: Self) -> Self {
Self(self.0.saturating_add(rhs.0))
Self::new(self.inner.saturating_add(rhs.inner))
}

/// Perform saturating subtraction of two lengths.
pub fn saturating_sub(self, rhs: Self) -> Self {
Self(self.0.saturating_sub(rhs.0))
Self::new(self.inner.saturating_sub(rhs.inner))
}

/// Get initial octet of the encoded length (if one is required).
Expand All @@ -89,7 +105,7 @@ impl Length {
/// > most significant bit;
/// > c) the value 11111111₂ shall not be used.
fn initial_octet(self) -> Option<u8> {
match self.0 {
match self.inner {
0x80..=0xFF => Some(0x81),
0x100..=0xFFFF => Some(0x82),
0x10000..=0xFFFFFF => Some(0x83),
Expand All @@ -103,10 +119,10 @@ impl Add for Length {
type Output = Result<Self>;

fn add(self, other: Self) -> Result<Self> {
self.0
.checked_add(other.0)
self.inner
.checked_add(other.inner)
.ok_or_else(|| ErrorKind::Overflow.into())
.map(Self)
.map(Self::new)
}
}

Expand Down Expand Up @@ -154,10 +170,10 @@ impl Sub for Length {
type Output = Result<Self>;

fn sub(self, other: Length) -> Result<Self> {
self.0
.checked_sub(other.0)
self.inner
.checked_sub(other.inner)
.ok_or_else(|| ErrorKind::Overflow.into())
.map(Self)
.map(Self::new)
}
}

Expand All @@ -171,25 +187,25 @@ impl Sub<Length> for Result<Length> {

impl From<u8> for Length {
fn from(len: u8) -> Length {
Length(len.into())
Length::new(len.into())
}
}

impl From<u16> for Length {
fn from(len: u16) -> Length {
Length(len.into())
Length::new(len.into())
}
}

impl From<u32> for Length {
fn from(len: u32) -> Length {
Length(len)
Length::new(len)
}
}

impl From<Length> for u32 {
fn from(length: Length) -> u32 {
length.0
length.inner
}
}

Expand All @@ -205,7 +221,7 @@ impl TryFrom<Length> for usize {
type Error = Error;

fn try_from(len: Length) -> Result<usize> {
len.0.try_into().map_err(|_| ErrorKind::Overflow.into())
len.inner.try_into().map_err(|_| ErrorKind::Overflow.into())
}
}

Expand Down Expand Up @@ -259,12 +275,12 @@ impl<'a> Decode<'a> for Length {

impl Encode for Length {
fn encoded_len(&self) -> Result<Length> {
match self.0 {
0..=0x7F => Ok(Length(1)),
0x80..=0xFF => Ok(Length(2)),
0x100..=0xFFFF => Ok(Length(3)),
0x10000..=0xFFFFFF => Ok(Length(4)),
0x1000000..=0xFFFFFFFF => Ok(Length(5)),
match self.inner {
0..=0x7F => Ok(Length::new(1)),
0x80..=0xFF => Ok(Length::new(2)),
0x100..=0xFFFF => Ok(Length::new(3)),
0x10000..=0xFFFFFF => Ok(Length::new(4)),
0x1000000..=0xFFFFFFFF => Ok(Length::new(5)),
}
}

Expand All @@ -274,15 +290,15 @@ impl Encode for Length {
writer.write_byte(tag_byte)?;

// Strip leading zeroes
match self.0.to_be_bytes() {
match self.inner.to_be_bytes() {
[0, 0, 0, byte] => writer.write_byte(byte),
[0, 0, bytes @ ..] => writer.write(&bytes),
[0, bytes @ ..] => writer.write(&bytes),
bytes => writer.write(&bytes),
}
}
#[allow(clippy::cast_possible_truncation)]
None => writer.write_byte(self.0 as u8),
None => writer.write_byte(self.inner as u8),
}
}
}
Expand All @@ -302,9 +318,19 @@ impl DerOrd for Length {
}
}

impl fmt::Debug for Length {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.indefinite {
write!(f, "Length({self} [indefinite])")
} else {
f.debug_tuple("Length").field(&self.inner).finish()
}
}
}

impl fmt::Display for Length {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
self.inner.fmt(f)
}
}

Expand All @@ -313,7 +339,7 @@ impl fmt::Display for Length {
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Length {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
Ok(Self(u.arbitrary()?))
Ok(Self::new(u.arbitrary()?))
}

fn size_hint(depth: usize) -> (usize, Option<usize>) {
Expand Down Expand Up @@ -362,11 +388,15 @@ fn decode_indefinite_length<'a, R: Reader<'a>>(reader: &mut R) -> Result<Length>

// Read the length byte and ensure it's zero (i.e. the full EOC is `00 00`)
let length_byte = reader.read_byte()?;
if length_byte == 0 {
return current_pos - start_pos;
} else {

if length_byte != 0 {
return Err(reader.error(ErrorKind::IndefiniteLength));
}

// Compute how much we read and flag the decoded length as indefinite
let mut ret = (current_pos - start_pos)?;
ret.indefinite = true;
return Ok(ret);
}

let header = Header::decode(reader)?;
Expand Down Expand Up @@ -492,6 +522,9 @@ mod tests {
27 F0 F0 00 00 00 00"
);

// Ensure the indefinite bit isn't set when decoding DER
assert!(!Length::from_der(&[0x00]).unwrap().indefinite);

let mut reader =
SliceReader::new_with_encoding_rules(&EXAMPLE_BER, EncodingRules::Ber).unwrap();

Expand All @@ -501,6 +534,7 @@ mod tests {

// Decode indefinite length
let length = Length::decode(&mut reader).unwrap();
assert!(length.indefinite);

// Decoding the length should leave the position at the end of the indefinite length octet
let pos = usize::try_from(reader.position()).unwrap();
Expand Down Expand Up @@ -530,6 +564,7 @@ mod tests {

// Parse the inner indefinite length
let length = Length::decode(&mut reader).unwrap();
assert!(length.indefinite);
assert_eq!(usize::try_from(length).unwrap(), 18);
}
}
8 changes: 7 additions & 1 deletion der/src/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ impl Tag {
Self::decode(&mut reader.clone())
}

/// Peek at whether the next byte in the reader has the constructed bit set.
pub(crate) fn peek_is_constructed<'a>(reader: &impl Reader<'a>) -> Result<bool> {
let octet = reader.clone().read_byte()?;
Ok(octet & CONSTRUCTED_FLAG != 0)
}

/// Returns true if given context-specific (or any given class) tag number matches the peeked tag.
pub(crate) fn peek_matches<'a, R: Reader<'a>>(
reader: &mut R,
Expand Down Expand Up @@ -400,7 +406,7 @@ impl Encode for Tag {
let length = if number <= 30 {
Length::ONE
} else {
Length::new(number.ilog2() as u16 / 7 + 2)
Length::new(number.ilog2() / 7 + 2)
};

Ok(length)
Expand Down
2 changes: 1 addition & 1 deletion der_derive/src/bitstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl DeriveBitString {

impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
fn value_len(&self) -> der::Result<der::Length> {
Ok(der::Length::new(#max_expected_bytes + 1))
Ok(der::Length::new(#max_expected_bytes as u32 + 1))
}

fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion pkcs5/src/pbes2/kdf/salt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Salt {

Ok(Self {
inner,
length: Length::new(slice.len() as u16),
length: Length::new(slice.len() as u32),
})
}

Expand Down