diff --git a/der/tests/derive.rs b/der/tests/derive.rs index fa2e1c7e9..510e9a621 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -11,10 +11,29 @@ // TODO: fix needless_question_mark in the derive crate #![allow(clippy::bool_assert_comparison, clippy::needless_question_mark)] +#[derive(Debug)] +#[allow(dead_code)] +pub struct CustomError(der::Error); + +impl From for CustomError { + fn from(value: der::Error) -> Self { + Self(value) + } +} + +impl From for CustomError { + fn from(_value: std::convert::Infallible) -> Self { + unreachable!() + } +} + /// Custom derive test cases for the `Choice` macro. mod choice { + use super::CustomError; + /// `Choice` with `EXPLICIT` tagging. mod explicit { + use super::CustomError; use der::{ asn1::{GeneralizedTime, UtcTime}, Choice, Decode, Encode, SliceWriter, @@ -50,6 +69,13 @@ mod choice { } } + #[derive(Choice)] + #[asn1(error = CustomError)] + pub enum WithCustomError { + #[asn1(type = "GeneralizedTime")] + Foo(GeneralizedTime), + } + const UTC_TIMESTAMP_DER: &[u8] = &hex!("17 0d 39 31 30 35 30 36 32 33 34 35 34 30 5a"); const GENERAL_TIMESTAMP_DER: &[u8] = &hex!("18 0f 31 39 39 31 30 35 30 36 32 33 34 35 34 30 5a"); @@ -61,6 +87,10 @@ mod choice { let general_time = Time::from_der(GENERAL_TIMESTAMP_DER).unwrap(); assert_eq!(general_time.to_unix_duration().as_secs(), 673573540); + + let WithCustomError::Foo(with_custom_error) = + WithCustomError::from_der(GENERAL_TIMESTAMP_DER).unwrap(); + assert_eq!(with_custom_error.to_unix_duration().as_secs(), 673573540); } #[test] @@ -154,6 +184,7 @@ mod choice { /// Custom derive test cases for the `Enumerated` macro. mod enumerated { + use super::CustomError; use der::{Decode, Encode, Enumerated, SliceWriter}; use hex_literal::hex; @@ -176,6 +207,14 @@ mod enumerated { const UNSPECIFIED_DER: &[u8] = &hex!("0a 01 00"); const KEY_COMPROMISE_DER: &[u8] = &hex!("0a 01 01"); + #[derive(Enumerated, Copy, Clone, Eq, PartialEq, Debug)] + #[asn1(error = CustomError)] + #[repr(u32)] + pub enum EnumWithCustomError { + Unspecified = 0, + Specified = 1, + } + #[test] fn decode() { let unspecified = CrlReason::from_der(UNSPECIFIED_DER).unwrap(); @@ -183,6 +222,9 @@ mod enumerated { let key_compromise = CrlReason::from_der(KEY_COMPROMISE_DER).unwrap(); assert_eq!(CrlReason::KeyCompromise, key_compromise); + + let custom_error_enum = EnumWithCustomError::from_der(UNSPECIFIED_DER).unwrap(); + assert_eq!(custom_error_enum, EnumWithCustomError::Unspecified); } #[test] @@ -202,6 +244,7 @@ mod enumerated { /// Custom derive test cases for the `Sequence` macro. #[cfg(feature = "oid")] mod sequence { + use super::CustomError; use core::marker::PhantomData; use der::{ asn1::{AnyRef, ObjectIdentifier, SetOf}, @@ -383,6 +426,12 @@ mod sequence { pub typed_context_specific_optional: Option<&'a [u8]>, } + #[derive(Sequence)] + #[asn1(error = CustomError)] + pub struct TypeWithCustomError { + pub simple: bool, + } + #[test] fn idp_test() { let idp = IssuingDistributionPointExample::from_der(&hex!("30038101FF")).unwrap(); @@ -444,6 +493,9 @@ mod sequence { PRIME256V1_OID, ObjectIdentifier::try_from(algorithm_identifier.parameters.unwrap()).unwrap() ); + + let t = TypeWithCustomError::from_der(&hex!("30030101FF")).unwrap(); + assert!(t.simple); } #[test] diff --git a/der_derive/src/attributes.rs b/der_derive/src/attributes.rs index fa050cbcb..911adcbd6 100644 --- a/der_derive/src/attributes.rs +++ b/der_derive/src/attributes.rs @@ -2,11 +2,33 @@ use crate::{Asn1Type, Tag, TagMode, TagNumber}; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use std::{fmt::Debug, str::FromStr}; use syn::punctuated::Punctuated; use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token}; +/// Error type used by the structure +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub(crate) enum ErrorType { + /// Represents the ::der::Error type + #[default] + Der, + /// Represents an error designed by Path + Custom(Path), +} + +impl ToTokens for ErrorType { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::Der => { + let err = quote! { ::der::Error }; + err.to_tokens(tokens) + } + Self::Custom(path) => path.to_tokens(tokens), + } + } +} + /// Attribute name. pub(crate) const ATTR_NAME: &str = "asn1"; @@ -18,37 +40,47 @@ pub(crate) struct TypeAttrs { /// /// The default value is `EXPLICIT`. pub tag_mode: TagMode, + pub error: ErrorType, } impl TypeAttrs { /// Parse attributes from a struct field or enum variant. pub fn parse(attrs: &[Attribute]) -> syn::Result { let mut tag_mode = None; + let mut error = None; - let mut parsed_attrs = Vec::new(); - AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?; - - for attr in parsed_attrs { - // `tag_mode = "..."` attribute - let mode = attr.parse_value("tag_mode")?.ok_or_else(|| { - syn::Error::new_spanned( - &attr.name, - "invalid `asn1` attribute (valid options are `tag_mode`)", - ) - })?; - - if tag_mode.is_some() { - return Err(syn::Error::new_spanned( - &attr.name, - "duplicate ASN.1 `tag_mode` attribute", - )); + attrs.iter().try_for_each(|attr| { + if !attr.path().is_ident(ATTR_NAME) { + return Ok(()); } - tag_mode = Some(mode); - } + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("tag_mode") { + if tag_mode.is_some() { + abort!(attr, "duplicate ASN.1 `tag_mode` attribute"); + } + + tag_mode = Some(meta.value()?.parse()?); + } else if meta.path.is_ident("error") { + if error.is_some() { + abort!(attr, "duplicate ASN.1 `error` attribute"); + } + + error = Some(ErrorType::Custom(meta.value()?.parse()?)); + } else { + return Err(syn::Error::new_spanned( + attr, + "invalid `asn1` attribute (valid options are `tag_mode` and `error`)", + )); + } + + Ok(()) + }) + })?; Ok(Self { tag_mode: tag_mode.unwrap_or_default(), + error: error.unwrap_or_default(), }) } } diff --git a/der_derive/src/choice.rs b/der_derive/src/choice.rs index 8683c6441..8cd50ca01 100644 --- a/der_derive/src/choice.rs +++ b/der_derive/src/choice.rs @@ -5,9 +5,9 @@ mod variant; use self::variant::ChoiceVariant; -use crate::{default_lifetime, TypeAttrs}; +use crate::{default_lifetime, ErrorType, TypeAttrs}; use proc_macro2::TokenStream; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; /// Derive the `Choice` trait for an enum. @@ -20,6 +20,9 @@ pub(crate) struct DeriveChoice { /// Variants of this `Choice`. variants: Vec, + + /// Error type for `DecodeValue` implementation. + error: ErrorType, } impl DeriveChoice { @@ -44,6 +47,7 @@ impl DeriveChoice { ident: input.ident, generics: input.generics.clone(), variants, + error: type_attrs.error.clone(), }) } @@ -84,6 +88,8 @@ impl DeriveChoice { tagged_body.push(variant.to_tagged_tokens()); } + let error = self.error.to_token_stream(); + quote! { impl #impl_generics ::der::Choice<#lifetime> for #ident #ty_generics #where_clause { fn can_decode(tag: ::der::Tag) -> bool { @@ -92,17 +98,20 @@ impl DeriveChoice { } impl #impl_generics ::der::Decode<#lifetime> for #ident #ty_generics #where_clause { - type Error = ::der::Error; + type Error = #error; - fn decode>(reader: &mut R) -> ::der::Result { + fn decode>(reader: &mut R) -> ::core::result::Result { use der::Reader as _; match ::der::Tag::peek(reader)? { #(#decode_body)* - actual => Err(der::ErrorKind::TagUnexpected { - expected: None, - actual - } - .into()), + actual => Err(::der::Error::new( + ::der::ErrorKind::TagUnexpected { + expected: None, + actual + }, + reader.position() + ).into() + ), } } } diff --git a/der_derive/src/enumerated.rs b/der_derive/src/enumerated.rs index 303014140..849a43081 100644 --- a/der_derive/src/enumerated.rs +++ b/der_derive/src/enumerated.rs @@ -2,11 +2,10 @@ //! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to //! enum variants. -use crate::attributes::AttrNameValue; -use crate::{default_lifetime, ATTR_NAME}; +use crate::{default_lifetime, ErrorType, ATTR_NAME}; use proc_macro2::TokenStream; -use quote::quote; -use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant}; +use quote::{quote, ToTokens}; +use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, LitStr, Path, Variant}; /// Valid options for the `#[repr]` attribute on `Enumerated` types. const REPR_TYPES: &[&str] = &["u8", "u16", "u32"]; @@ -24,6 +23,9 @@ pub(crate) struct DeriveEnumerated { /// Variants of this enum. variants: Vec, + + /// Error type for `DecodeValue` implementation. + error: ErrorType, } impl DeriveEnumerated { @@ -40,22 +42,30 @@ impl DeriveEnumerated { // Reject `asn1` attributes, parse the `repr` attribute let mut repr: Option = None; let mut integer = false; + let mut error: Option = None; for attr in &input.attrs { if attr.path().is_ident(ATTR_NAME) { - let kvs = match AttrNameValue::parse_attribute(attr) { - Ok(kvs) => kvs, - Err(e) => abort!(attr, e), - }; - for anv in kvs { - if anv.name.is_ident("type") { - match anv.value.value().as_str() { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("type") { + let value: LitStr = meta.value()?.parse()?; + match value.value().as_str() { "ENUMERATED" => integer = false, "INTEGER" => integer = true, - s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")), + s => abort!(value, format_args!("`type = \"{s}\"` is unsupported")), } + } else if meta.path.is_ident("error") { + let path: Path = meta.value()?.parse()?; + error = Some(ErrorType::Custom(path)); + } else { + return Err(syn::Error::new_spanned( + &meta.path, + "invalid `asn1` attribute (valid options are `type` and `error`)", + )); } - } + + Ok(()) + })?; } else if attr.path().is_ident("repr") { if repr.is_some() { abort!( @@ -97,6 +107,7 @@ impl DeriveEnumerated { })?, variants, integer, + error: error.unwrap_or_default(), }) } @@ -115,14 +126,16 @@ impl DeriveEnumerated { try_from_body.push(variant.to_try_from_tokens()); } + let error = self.error.to_token_stream(); + quote! { impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident { - type Error = ::der::Error; + type Error = #error; fn decode_value>( reader: &mut R, header: ::der::Header - ) -> ::der::Result { + ) -> ::core::result::Result { <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into() } } @@ -142,12 +155,12 @@ impl DeriveEnumerated { } impl TryFrom<#repr> for #ident { - type Error = ::der::Error; + type Error = #error; - fn try_from(n: #repr) -> ::der::Result { + fn try_from(n: #repr) -> ::core::result::Result { match n { #(#try_from_body)* - _ => Err(#tag.value_error()) + _ => Err(#tag.value_error().into()) } } } diff --git a/der_derive/src/lib.rs b/der_derive/src/lib.rs index 79e73b663..c87cd902a 100644 --- a/der_derive/src/lib.rs +++ b/der_derive/src/lib.rs @@ -45,6 +45,21 @@ //! The default is `EXPLICIT`, so the attribute only needs to be added when //! a particular module is declared `IMPLICIT`. //! +//! ### `#[asn1(error = ...)]` attribute: custom error types for decoding +//! +//! By default generated `Decode` / `DecodeValue` implementations generated by macros +//! from this crate use `der::Error` as the generic `Error` parameter, but it's +//! possible to use a custom error type that implements `From` by using +//! this attribute. +//! +//! Note that [`Choice`] puts more restrictions on the error type: during decoding +//! for each enum variant the type in its `#[asn1(type = "...")]` attribute (let's +//! call it `T`) is constructed and then converted to the actual variant's type +//! (this one will be `U`) using the `TryInto` trait. That means that for each enum +//! variant's type `U` the custom error type must implement +//! `From<>::Error>`. Since `U` and `T` types are usually the same +//! implementing `From` should do it. +//! //! ## Field-level attributes //! //! The following attributes can be added to either the fields of a particular @@ -144,7 +159,7 @@ mod value_ord; use crate::{ asn1_type::Asn1Type, - attributes::{FieldAttrs, TypeAttrs, ATTR_NAME}, + attributes::{ErrorType, FieldAttrs, TypeAttrs, ATTR_NAME}, choice::DeriveChoice, enumerated::DeriveEnumerated, sequence::DeriveSequence, diff --git a/der_derive/src/sequence.rs b/der_derive/src/sequence.rs index 81ca3d729..f347c727f 100644 --- a/der_derive/src/sequence.rs +++ b/der_derive/src/sequence.rs @@ -3,10 +3,10 @@ mod field; -use crate::{default_lifetime, TypeAttrs}; +use crate::{default_lifetime, ErrorType, TypeAttrs}; use field::SequenceField; use proc_macro2::TokenStream; -use quote::quote; +use quote::{quote, ToTokens}; use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; /// Derive the `Sequence` trait for a struct @@ -19,6 +19,9 @@ pub(crate) struct DeriveSequence { /// Fields of the struct. fields: Vec, + + /// Error type for `DecodeValue` implementation. + error: ErrorType, } impl DeriveSequence { @@ -44,6 +47,7 @@ impl DeriveSequence { ident: input.ident, generics: input.generics.clone(), fields, + error: type_attrs.error.clone(), }) } @@ -84,14 +88,16 @@ impl DeriveSequence { encode_fields.push(quote!(#field.encode(writer)?;)); } + let error = self.error.to_token_stream(); + quote! { impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause { - type Error = ::der::Error; + type Error = #error; fn decode_value>( reader: &mut R, header: ::der::Header, - ) -> ::der::Result { + ) -> ::core::result::Result { use ::der::{Decode as _, DecodeValue as _, Reader as _}; reader.read_nested(header.length, |reader| { diff --git a/der_derive/src/tag.rs b/der_derive/src/tag.rs index aab2899b5..a1cf529cb 100644 --- a/der_derive/src/tag.rs +++ b/der_derive/src/tag.rs @@ -7,6 +7,7 @@ use std::{ fmt::{self, Display}, str::FromStr, }; +use syn::{parse::Parse, LitStr}; /// Tag "IR" type. #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] @@ -78,6 +79,21 @@ impl TagMode { } } +impl Parse for TagMode { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let s: LitStr = input.parse()?; + + match s.value().as_str() { + "EXPLICIT" | "explicit" => Ok(TagMode::Explicit), + "IMPLICIT" | "implicit" => Ok(TagMode::Implicit), + _ => Err(syn::Error::new( + s.span(), + "invalid tag mode (supported modes are `EXPLICIT` and `IMPLICIT`)", + )), + } + } +} + impl FromStr for TagMode { type Err = ParseError;