diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 9768a3741b..37f89afcc2 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -420,80 +420,49 @@ fn derive_known_layout_inner(ast: &DeriveInput, _top_level: Trait) -> Result { // A bound on the trailing field is not required, since enums cannot // currently be unsized. - impl_block( - ast, - enm, - Trait::KnownLayout, - FieldBounds::None, - SelfBounds::SIZED, - None, - Some(inner_extras), - outer_extras, - ) + ImplBlockBuilder::new(ast, enm, Trait::KnownLayout, FieldBounds::None) + .self_type_trait_bounds(SelfBounds::SIZED) + .inner_extras(inner_extras) + .outer_extras(outer_extras) + .build() } Data::Union(unn) => { // A bound on the trailing field is not required, since unions // cannot currently be unsized. - impl_block( - ast, - unn, - Trait::KnownLayout, - FieldBounds::None, - SelfBounds::SIZED, - None, - Some(inner_extras), - outer_extras, - ) + ImplBlockBuilder::new(ast, unn, Trait::KnownLayout, FieldBounds::None) + .self_type_trait_bounds(SelfBounds::SIZED) + .inner_extras(inner_extras) + .outer_extras(outer_extras) + .build() } }) } fn derive_no_cell_inner(ast: &DeriveInput, _top_level: Trait) -> TokenStream { match &ast.data { - Data::Struct(strct) => impl_block( - ast, - strct, - Trait::Immutable, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - None, - None, - ), - Data::Enum(enm) => impl_block( - ast, - enm, - Trait::Immutable, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - None, - None, - ), - Data::Union(unn) => impl_block( - ast, - unn, - Trait::Immutable, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - None, - None, - ), + Data::Struct(strct) => { + ImplBlockBuilder::new(ast, strct, Trait::Immutable, FieldBounds::ALL_SELF).build() + } + Data::Enum(enm) => { + ImplBlockBuilder::new(ast, enm, Trait::Immutable, FieldBounds::ALL_SELF).build() + } + Data::Union(unn) => { + ImplBlockBuilder::new(ast, unn, Trait::Immutable, FieldBounds::ALL_SELF).build() + } } } @@ -685,16 +654,9 @@ fn derive_try_from_bytes_struct( } ) }); - Ok(impl_block( - ast, - strct, - Trait::TryFromBytes, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - Some(extras), - None, - )) + Ok(ImplBlockBuilder::new(ast, strct, Trait::TryFromBytes, FieldBounds::ALL_SELF) + .inner_extras(extras) + .build()) } /// A union is `TryFromBytes` if: @@ -755,16 +717,9 @@ fn derive_try_from_bytes_union( } ) }); - impl_block( - ast, - unn, - Trait::TryFromBytes, - field_type_trait_bounds, - SelfBounds::None, - None, - Some(extras), - None, - ) + ImplBlockBuilder::new(ast, unn, Trait::TryFromBytes, field_type_trait_bounds) + .inner_extras(extras) + .build() } fn derive_try_from_bytes_enum( @@ -792,16 +747,9 @@ fn derive_try_from_bytes_enum( (None, false) => r#enum::derive_is_bit_valid(&ast.ident, &repr, &ast.generics, enm)?, }; - Ok(impl_block( - ast, - enm, - Trait::TryFromBytes, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - Some(extra), - None, - )) + Ok(ImplBlockBuilder::new(ast, enm, Trait::TryFromBytes, FieldBounds::ALL_SELF) + .inner_extras(extra) + .build()) } /// Attempts to generate a `TryFromBytes::is_bit_valid` instance that @@ -897,16 +845,7 @@ unsafe fn gen_trivial_is_bit_valid_unchecked() -> proc_macro2::TokenStream { /// A struct is `FromZeros` if: /// - all fields are `FromZeros` fn derive_from_zeros_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { - impl_block( - ast, - strct, - Trait::FromZeros, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - None, - None, - ) + ImplBlockBuilder::new(ast, strct, Trait::FromZeros, FieldBounds::ALL_SELF).build() } /// Returns `Ok(index)` if variant `index` of the enum has a discriminant of @@ -1034,16 +973,8 @@ fn derive_from_zeros_enum(ast: &DeriveInput, enm: &DataEnum) -> Result>(); - Ok(impl_block( - ast, - enm, - Trait::FromZeros, - FieldBounds::Explicit(explicit_bounds), - SelfBounds::None, - None, - None, - None, - )) + Ok(ImplBlockBuilder::new(ast, enm, Trait::FromZeros, FieldBounds::Explicit(explicit_bounds)) + .build()) } /// Unions are `FromZeros` if @@ -1053,31 +984,13 @@ fn derive_from_zeros_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { // compatibility with `derive(TryFromBytes)` on unions; not for soundness. let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); - impl_block( - ast, - unn, - Trait::FromZeros, - field_type_trait_bounds, - SelfBounds::None, - None, - None, - None, - ) + ImplBlockBuilder::new(ast, unn, Trait::FromZeros, field_type_trait_bounds).build() } /// A struct is `FromBytes` if: /// - all fields are `FromBytes` fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { - impl_block( - ast, - strct, - Trait::FromBytes, - FieldBounds::ALL_SELF, - SelfBounds::None, - None, - None, - None, - ) + ImplBlockBuilder::new(ast, strct, Trait::FromBytes, FieldBounds::ALL_SELF).build() } /// An enum is `FromBytes` if: @@ -1109,16 +1022,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> Result TokenStream { // compatibility with `derive(TryFromBytes)` on unions; not for soundness. let field_type_trait_bounds = FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); - impl_block( - ast, - unn, - Trait::FromBytes, - field_type_trait_bounds, - SelfBounds::None, - None, - None, - None, - ) + ImplBlockBuilder::new(ast, unn, Trait::FromBytes, field_type_trait_bounds).build() } fn derive_into_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> Result { @@ -1219,16 +1114,9 @@ fn derive_into_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> Result Result Result Result Result impl '_ + Iterator( - input: &DeriveInput, - data: &D, +struct ImplBlockBuilder<'a, D: DataExt> { + input: &'a DeriveInput, + data: &'a D, trt: Trait, - field_type_trait_bounds: FieldBounds, - self_type_trait_bounds: SelfBounds, + field_type_trait_bounds: FieldBounds<'a>, + self_type_trait_bounds: SelfBounds<'a>, padding_check: Option, inner_extras: Option, outer_extras: Option, -) -> TokenStream { - // In this documentation, we will refer to this hypothetical struct: - // - // #[derive(FromBytes)] - // struct Foo - // where - // T: Copy, - // I: Clone, - // I::Item: Clone, - // { - // a: u8, - // b: T, - // c: I::Item, - // } - // - // We extract the field types, which in this case are `u8`, `T`, and - // `I::Item`. We re-use the existing parameters and where clauses. If - // `require_trait_bound == true` (as it is for `FromBytes), we add where - // bounds for each field's type: - // - // impl FromBytes for Foo - // where - // T: Copy, - // I: Clone, - // I::Item: Clone, - // T: FromBytes, - // I::Item: FromBytes, - // { - // } - // - // NOTE: It is standard practice to only emit bounds for the type parameters - // themselves, not for field types based on those parameters (e.g., `T` vs - // `T::Foo`). For a discussion of why this is standard practice, see - // https://github.com/rust-lang/rust/issues/26925. - // - // The reason we diverge from this standard is that doing it that way for us - // would be unsound. E.g., consider a type, `T` where `T: FromBytes` but - // `T::Foo: !FromBytes`. It would not be sound for us to accept a type with - // a `T::Foo` field as `FromBytes` simply because `T: FromBytes`. - // - // While there's no getting around this requirement for us, it does have the - // pretty serious downside that, when lifetimes are involved, the trait - // solver ties itself in knots: - // - // #[derive(Unaligned)] - // #[repr(C)] - // struct Dup<'a, 'b> { - // a: PhantomData<&'a u8>, - // b: PhantomData<&'b u8>, - // } - // - // error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned` - // --> src/main.rs:6:10 - // | - // 6 | #[derive(Unaligned)] - // | ^^^^^^^^^ - // | - // = note: required by `zerocopy::Unaligned` - - let type_ident = &input.ident; - let trait_path = trt.crate_path(); - let fields = data.fields(); - let variants = data.variants(); - let tag = data.tag(); - - fn bound_tt(ty: &Type, traits: impl Iterator) -> WherePredicate { - let traits = traits.map(|t| t.crate_path()); - parse_quote!(#ty: #(#traits)+*) +} + +impl<'a, D: DataExt> ImplBlockBuilder<'a, D> { + fn new( + input: &'a DeriveInput, + data: &'a D, + trt: Trait, + field_type_trait_bounds: FieldBounds<'a>, + ) -> Self { + Self { + input, + data, + trt, + field_type_trait_bounds, + self_type_trait_bounds: SelfBounds::None, + padding_check: None, + inner_extras: None, + outer_extras: None, + } } - let field_type_bounds: Vec<_> = match (field_type_trait_bounds, &fields[..]) { - (FieldBounds::All(traits), _) => fields - .iter() - .map(|(_vis, _name, ty)| bound_tt(ty, normalize_bounds(trt, traits))) - .collect(), - (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![], - (FieldBounds::Trailing(traits), [.., last]) => { - vec![bound_tt(last.2, normalize_bounds(trt, traits))] + + fn self_type_trait_bounds(mut self, self_type_trait_bounds: SelfBounds<'a>) -> Self { + self.self_type_trait_bounds = self_type_trait_bounds; + self + } + + fn padding_check>>(mut self, padding_check: P) -> Self { + self.padding_check = padding_check.into(); + self + } + + fn inner_extras(mut self, inner_extras: TokenStream) -> Self { + self.inner_extras = Some(inner_extras); + self + } + + fn outer_extras>>(mut self, outer_extras: T) -> Self { + self.outer_extras = outer_extras.into(); + self + } + + fn build(self) -> TokenStream { + // In this documentation, we will refer to this hypothetical struct: + // + // #[derive(FromBytes)] + // struct Foo + // where + // T: Copy, + // I: Clone, + // I::Item: Clone, + // { + // a: u8, + // b: T, + // c: I::Item, + // } + // + // We extract the field types, which in this case are `u8`, `T`, and + // `I::Item`. We re-use the existing parameters and where clauses. If + // `require_trait_bound == true` (as it is for `FromBytes), we add where + // bounds for each field's type: + // + // impl FromBytes for Foo + // where + // T: Copy, + // I: Clone, + // I::Item: Clone, + // T: FromBytes, + // I::Item: FromBytes, + // { + // } + // + // NOTE: It is standard practice to only emit bounds for the type + // parameters themselves, not for field types based on those parameters + // (e.g., `T` vs `T::Foo`). For a discussion of why this is standard + // practice, see https://github.com/rust-lang/rust/issues/26925. + // + // The reason we diverge from this standard is that doing it that way + // for us would be unsound. E.g., consider a type, `T` where `T: + // FromBytes` but `T::Foo: !FromBytes`. It would not be sound for us to + // accept a type with a `T::Foo` field as `FromBytes` simply because `T: + // FromBytes`. + // + // While there's no getting around this requirement for us, it does have + // the pretty serious downside that, when lifetimes are involved, the + // trait solver ties itself in knots: + // + // #[derive(Unaligned)] + // #[repr(C)] + // struct Dup<'a, 'b> { + // a: PhantomData<&'a u8>, + // b: PhantomData<&'b u8>, + // } + // + // error[E0283]: type annotations required: cannot resolve `core::marker::PhantomData<&'a u8>: zerocopy::Unaligned` + // --> src/main.rs:6:10 + // | + // 6 | #[derive(Unaligned)] + // | ^^^^^^^^^ + // | + // = note: required by `zerocopy::Unaligned` + + let type_ident = &self.input.ident; + let trait_path = self.trt.crate_path(); + let fields = self.data.fields(); + let variants = self.data.variants(); + let tag = self.data.tag(); + + fn bound_tt(ty: &Type, traits: impl Iterator) -> WherePredicate { + let traits = traits.map(|t| t.crate_path()); + parse_quote!(#ty: #(#traits)+*) } - (FieldBounds::Explicit(bounds), _) => bounds, - }; + let field_type_bounds: Vec<_> = match (self.field_type_trait_bounds, &fields[..]) { + (FieldBounds::All(traits), _) => fields + .iter() + .map(|(_vis, _name, ty)| bound_tt(ty, normalize_bounds(self.trt, traits))) + .collect(), + (FieldBounds::None, _) | (FieldBounds::Trailing(..), []) => vec![], + (FieldBounds::Trailing(traits), [.., last]) => { + vec![bound_tt(last.2, normalize_bounds(self.trt, traits))] + } + (FieldBounds::Explicit(bounds), _) => bounds, + }; - // Don't bother emitting a padding check if there are no fields. - #[allow(unstable_name_collisions)] // See `BoolExt` below - // Work around https://github.com/rust-lang/rust-clippy/issues/12280 - #[allow(clippy::incompatible_msrv)] - let padding_check_bound = - padding_check.and_then(|check| (!fields.is_empty()).then_some(check)).map(|check| { - let variant_types = variants.iter().map(|var| { - let types = var.iter().map(|(_vis, _name, ty)| ty); - quote!([#(#types),*]) + // Don't bother emitting a padding check if there are no fields. + #[allow(unstable_name_collisions)] // See `BoolExt` below + // Work around https://github.com/rust-lang/rust-clippy/issues/12280 + #[allow(clippy::incompatible_msrv)] + let padding_check_bound = self + .padding_check + .and_then(|check| (!fields.is_empty()).then_some(check)) + .map(|check| { + let variant_types = variants.iter().map(|var| { + let types = var.iter().map(|(_vis, _name, ty)| ty); + quote!([#(#types),*]) + }); + let validator_context = check.validator_macro_context(); + let validator_macro = check.validator_macro_ident(); + let t = tag.iter(); + parse_quote! { + (): ::zerocopy::util::macro_util::PaddingFree< + Self, + { + #validator_context + ::zerocopy::#validator_macro!(Self, #(#t,)* #(#variant_types),*) + } + > + } }); - let validator_context = check.validator_macro_context(); - let validator_macro = check.validator_macro_ident(); - let t = tag.iter(); - parse_quote! { - (): ::zerocopy::util::macro_util::PaddingFree< - Self, - { - #validator_context - ::zerocopy::#validator_macro!(Self, #(#t,)* #(#variant_types),*) - } - > - } - }); - let self_bounds: Option = match self_type_trait_bounds { - SelfBounds::None => None, - SelfBounds::All(traits) => Some(bound_tt(&parse_quote!(Self), traits.iter().copied())), - }; + let self_bounds: Option = match self.self_type_trait_bounds { + SelfBounds::None => None, + SelfBounds::All(traits) => Some(bound_tt(&parse_quote!(Self), traits.iter().copied())), + }; - let bounds = input - .generics - .where_clause - .as_ref() - .map(|where_clause| where_clause.predicates.iter()) - .into_iter() - .flatten() - .chain(field_type_bounds.iter()) - .chain(padding_check_bound.iter()) - .chain(self_bounds.iter()); - - // The parameters with trait bounds, but without type defaults. - let params = input.generics.params.clone().into_iter().map(|mut param| { - match &mut param { - GenericParam::Type(ty) => ty.default = None, - GenericParam::Const(cnst) => cnst.default = None, - GenericParam::Lifetime(_) => {} - } - quote!(#param) - }); + let bounds = self + .input + .generics + .where_clause + .as_ref() + .map(|where_clause| where_clause.predicates.iter()) + .into_iter() + .flatten() + .chain(field_type_bounds.iter()) + .chain(padding_check_bound.iter()) + .chain(self_bounds.iter()); + + // The parameters with trait bounds, but without type defaults. + let params = self.input.generics.params.clone().into_iter().map(|mut param| { + match &mut param { + GenericParam::Type(ty) => ty.default = None, + GenericParam::Const(cnst) => cnst.default = None, + GenericParam::Lifetime(_) => {} + } + quote!(#param) + }); - // The identifiers of the parameters without trait bounds or type defaults. - let param_idents = input.generics.params.iter().map(|param| match param { - GenericParam::Type(ty) => { - let ident = &ty.ident; - quote!(#ident) - } - GenericParam::Lifetime(l) => { - let ident = &l.lifetime; - quote!(#ident) - } - GenericParam::Const(cnst) => { - let ident = &cnst.ident; - quote!({#ident}) - } - }); + // The identifiers of the parameters without trait bounds or type + // defaults. + let param_idents = self.input.generics.params.iter().map(|param| match param { + GenericParam::Type(ty) => { + let ident = &ty.ident; + quote!(#ident) + } + GenericParam::Lifetime(l) => { + let ident = &l.lifetime; + quote!(#ident) + } + GenericParam::Const(cnst) => { + let ident = &cnst.ident; + quote!({#ident}) + } + }); - let impl_tokens = quote! { - // TODO(#553): Add a test that generates a warning when - // `#[allow(deprecated)]` isn't present. - #[allow(deprecated)] - // While there are not currently any warnings that this suppresses (that - // we're aware of), it's good future-proofing hygiene. - #[automatically_derived] - unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* > - where - #(#bounds,)* - { - fn only_derive_is_allowed_to_implement_this_trait() {} + let inner_extras = self.inner_extras; + let impl_tokens = quote! { + // TODO(#553): Add a test that generates a warning when + // `#[allow(deprecated)]` isn't present. + #[allow(deprecated)] + // While there are not currently any warnings that this suppresses + // (that we're aware of), it's good future-proofing hygiene. + #[automatically_derived] + unsafe impl < #(#params),* > #trait_path for #type_ident < #(#param_idents),* > + where + #(#bounds,)* + { + fn only_derive_is_allowed_to_implement_this_trait() {} - #inner_extras - } - }; + #inner_extras + } + }; - if let Some(outer_extras) = outer_extras { - // So that any items defined in `#outer_extras` don't conflict with - // existing names defined in this scope. - quote! { - const _: () = { - #impl_tokens + if let Some(outer_extras) = self.outer_extras { + // So that any items defined in `#outer_extras` don't conflict with + // existing names defined in this scope. + quote! { + const _: () = { + #impl_tokens - #outer_extras - }; + #outer_extras + }; + } + } else { + impl_tokens } - } else { - impl_tokens } }