From 26c9f7255a0e12e48132ac3dbf0f8519ed3243dc Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 23 Mar 2026 11:27:06 +0100 Subject: [PATCH] Allow applying autodiff macros to trait functions. It will use enzyme to generate a default derivative implementation, which can be overwritten by the user. --- .../src/attributes/autodiff.rs | 1 + compiler/rustc_builtin_macros/src/autodiff.rs | 30 +++++++------ compiler/rustc_expand/src/base.rs | 4 +- tests/pretty/autodiff/trait.pp | 43 +++++++++++++++++++ tests/pretty/autodiff/trait.rs | 32 ++++++++++++++ 5 files changed, 94 insertions(+), 16 deletions(-) create mode 100644 tests/pretty/autodiff/trait.pp create mode 100644 tests/pretty/autodiff/trait.rs diff --git a/compiler/rustc_attr_parsing/src/attributes/autodiff.rs b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs index 118a4103b1a96..c72ff224a1502 100644 --- a/compiler/rustc_attr_parsing/src/attributes/autodiff.rs +++ b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs @@ -24,6 +24,7 @@ impl SingleAttributeParser for RustcAutodiffParser { Allow(Target::Fn), Allow(Target::Method(MethodKind::Inherent)), Allow(Target::Method(MethodKind::Trait { body: true })), + Allow(Target::Method(MethodKind::Trait { body: false })), Allow(Target::Method(MethodKind::TraitImpl)), ]); const TEMPLATE: AttributeTemplate = template!( diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 30391e74480fe..afa393a545cd4 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -224,16 +224,18 @@ mod llvm_enzyme { } _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { of_trait: _ }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( - assoc_item.vis.clone(), - sig.clone(), - ident.clone(), - generics.clone(), - true, - )), - _ => None, - }, + Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => { + match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + true, + )), + _ => None, + } + } _ => None, }) else { dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); @@ -393,14 +395,14 @@ mod llvm_enzyme { } Annotatable::Item(iitem.clone()) } - Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => { + Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. } | Trait)) => { if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { assoc_item.attrs.push(attr); } if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { has_inline_never = true; } - Annotatable::AssocItem(assoc_item.clone(), i) + Annotatable::AssocItem(assoc_item.clone(), ctxt) } Annotatable::Stmt(ref mut stmt) => { match stmt.kind { @@ -441,7 +443,7 @@ mod llvm_enzyme { } let d_annotatable = match &item { - Annotatable::AssocItem(_, _) => { + Annotatable::AssocItem(_, ctxt) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn); let d_fn = Box::new(ast::AssocItem { attrs: d_attrs, @@ -451,7 +453,7 @@ mod llvm_enzyme { kind: assoc_item, tokens: None, }); - Annotatable::AssocItem(d_fn, Impl { of_trait: false }) + Annotatable::AssocItem(d_fn, *ctxt) } Annotatable::Item(_) => { let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn)); diff --git a/compiler/rustc_expand/src/base.rs b/compiler/rustc_expand/src/base.rs index 225906dfba2de..44f148f9c15aa 100644 --- a/compiler/rustc_expand/src/base.rs +++ b/compiler/rustc_expand/src/base.rs @@ -149,14 +149,14 @@ impl Annotatable { pub fn expect_trait_item(self) -> Box { match self { Annotatable::AssocItem(i, AssocCtxt::Trait) => i, - _ => panic!("expected Item"), + _ => panic!("expected trait item"), } } pub fn expect_impl_item(self) -> Box { match self { Annotatable::AssocItem(i, AssocCtxt::Impl { .. }) => i, - _ => panic!("expected Item"), + _ => panic!("expected impl item"), } } diff --git a/tests/pretty/autodiff/trait.pp b/tests/pretty/autodiff/trait.pp new file mode 100644 index 0000000000000..cb2fe0c4a4efa --- /dev/null +++ b/tests/pretty/autodiff/trait.pp @@ -0,0 +1,43 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] +#![feature(core_intrinsics)] +#![feature(rustc_attrs)] + +use std::autodiff::autodiff_reverse; + +struct Foo { + a: f64, +} + +trait MyTrait { + #[rustc_autodiff] + fn f(&self, x: f64) -> f64; + #[rustc_autodiff(Reverse, 1, Const, Active, Active)] + fn df(&self, x: f64, seed: f64) -> (f64, f64) { + std::hint::black_box(seed); + std::hint::black_box(x); + ::std::intrinsics::autodiff( + Self::f as for<'a> fn(&'a Self, _: f64) -> f64, + Self::df, + (self, x, seed), + ) + + } +} + +impl MyTrait for Foo { + fn f(&self, x: f64) -> f64 { + x.sin() + } +} + +fn main() { + let foo = Foo { a: 3.0f64 }; + dbg!(foo.df(2.0, 1.0)); + dbg!(2.0_f64.cos()); +} diff --git a/tests/pretty/autodiff/trait.rs b/tests/pretty/autodiff/trait.rs new file mode 100644 index 0000000000000..a308cf3fb2adb --- /dev/null +++ b/tests/pretty/autodiff/trait.rs @@ -0,0 +1,32 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] +#![feature(core_intrinsics)] +#![feature(rustc_attrs)] + +use std::autodiff::autodiff_reverse; + +struct Foo { + a: f64, +} + +trait MyTrait { + #[autodiff_reverse(df, Const, Active, Active)] + fn f(&self, x: f64) -> f64; +} + +impl MyTrait for Foo { + fn f(&self, x: f64) -> f64 { + x.sin() + } +} + +fn main() { + let foo = Foo { a: 3.0f64 }; + dbg!(foo.df(2.0, 1.0)); + dbg!(2.0_f64.cos()); +}