From b003b38a36b717a05abeae285f40bbe77d475f05 Mon Sep 17 00:00:00 2001 From: konstin Date: Sun, 21 Jul 2019 14:08:20 +0200 Subject: [PATCH 1/2] POC: Dunder/Protocols without specialization --- examples/rustapi_module/setup.py | 5 ++ examples/rustapi_module/src/dunder.rs | 25 +++++++ examples/rustapi_module/src/lib.rs | 1 + .../rustapi_module/tests/test_datetime.py | 3 +- examples/rustapi_module/tests/test_dunder.py | 5 ++ pyo3-derive-backend/src/pyclass.rs | 27 +++++++ pyo3-derive-backend/src/pyimpl.rs | 57 +++++++++++++- src/class/methods.rs | 74 ++++++++++++++++++- src/class/number.rs | 22 ++++-- src/type_object.rs | 7 +- 10 files changed, 212 insertions(+), 14 deletions(-) create mode 100644 examples/rustapi_module/src/dunder.rs create mode 100644 examples/rustapi_module/tests/test_dunder.py diff --git a/examples/rustapi_module/setup.py b/examples/rustapi_module/setup.py index 2dbaebb047d..fb51a5b19c4 100644 --- a/examples/rustapi_module/setup.py +++ b/examples/rustapi_module/setup.py @@ -102,6 +102,11 @@ def get_py_version_cfgs(): "Cargo.toml", rustc_flags=get_py_version_cfgs(), ), + RustExtension( + "rustapi_module.dunder", + "Cargo.toml", + rustc_flags=get_py_version_cfgs(), + ), ], install_requires=install_requires, tests_require=tests_require, diff --git a/examples/rustapi_module/src/dunder.rs b/examples/rustapi_module/src/dunder.rs new file mode 100644 index 00000000000..2ec7de11768 --- /dev/null +++ b/examples/rustapi_module/src/dunder.rs @@ -0,0 +1,25 @@ +use pyo3::prelude::*; + +#[pymodule] +fn dunder(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} + +#[pyclass] +pub struct Number { + value: u32, +} + +#[pymethods] +impl Number { + #[new] + fn new(obj: &PyRawObject, value: u32) { + obj.init(Number { value }) + } + + /// Very basic add function + fn __add__(&self, other: u32) -> PyResult { + Ok(self.value + other) + } +} diff --git a/examples/rustapi_module/src/lib.rs b/examples/rustapi_module/src/lib.rs index be7fd6ed63c..e728b74b101 100644 --- a/examples/rustapi_module/src/lib.rs +++ b/examples/rustapi_module/src/lib.rs @@ -1,4 +1,5 @@ pub mod datetime; pub mod dict_iter; +pub mod dunder; pub mod othermod; pub mod subclassing; diff --git a/examples/rustapi_module/tests/test_datetime.py b/examples/rustapi_module/tests/test_datetime.py index 71f8c55301c..da40c0b21ba 100644 --- a/examples/rustapi_module/tests/test_datetime.py +++ b/examples/rustapi_module/tests/test_datetime.py @@ -8,7 +8,6 @@ from hypothesis import strategies as st from hypothesis.strategies import dates, datetimes - # Constants def _get_utc(): timezone = getattr(pdt, "timezone", None) @@ -310,4 +309,4 @@ def test_tz_class_introspection(): tzi = rdt.TzClass() assert tzi.__class__ == rdt.TzClass - assert repr(tzi).startswith(" TokenStream { // it comes up in error messages let name = cls.to_string() + "GeneratedPyo3Inventory"; let inventory_cls = syn::Ident::new(&name, Span::call_site()); + let protocol_name = cls.to_string() + "GeneratedPyo3InventoryProtocol"; + let protocol_inventory_cls = syn::Ident::new(&protocol_name, Span::call_site()); quote! { #[doc(hidden)] @@ -241,6 +243,31 @@ fn impl_inventory(cls: &syn::Ident) -> TokenStream { } pyo3::inventory::collect!(#inventory_cls); + + // Dunder methods/Protocol support + + #[doc(hidden)] + pub struct #protocol_inventory_cls { + methods: &'static [pyo3::methods::protocols::PyProcotolMethodWrapped], + } + + impl pyo3::class::methods::protocols::PyProtocolInventory for #protocol_inventory_cls { + fn new(methods: &'static [pyo3::methods::protocols::PyProcotolMethodWrapped]) -> Self { + Self { + methods + } + } + + fn get_methods(&self) -> &'static [pyo3::methods::protocols::PyProcotolMethodWrapped] { + self.methods + } + } + + impl pyo3::class::methods::protocols::PyProtocolInventoryDispatch for #cls { + type ProtocolInventoryType = #protocol_inventory_cls; + } + + pyo3::inventory::collect!(#protocol_inventory_cls); } } diff --git a/pyo3-derive-backend/src/pyimpl.rs b/pyo3-derive-backend/src/pyimpl.rs index 64abab02637..13a3271b768 100644 --- a/pyo3-derive-backend/src/pyimpl.rs +++ b/pyo3-derive-backend/src/pyimpl.rs @@ -20,12 +20,60 @@ pub fn build_py_methods(ast: &mut syn::ItemImpl) -> syn::Result { } } +fn binary_func_protocol_wrap(ty: &syn::Type, name: &syn::Ident) -> TokenStream { + quote! {{ + #[allow(unused_mut)] + unsafe extern "C" fn wrap( + lhs: *mut pyo3::ffi::PyObject, + rhs: *mut pyo3::ffi::PyObject, + ) -> *mut pyo3::ffi::PyObject { + use pyo3::ObjectProtocol; + let _pool = pyo3::GILPool::new(); + let py = pyo3::Python::assume_gil_acquired(); + let lhs = py.from_borrowed_ptr::(lhs); + let rhs = py.from_borrowed_ptr::(rhs); + + let result = match lhs.extract() { + Ok(lhs) => match rhs.extract() { + Ok(rhs) => #ty::#name(lhs, rhs).into(), + Err(e) => Err(e.into()), + }, + Err(e) => Err(e.into()), + }; + pyo3::callback::cb_convert(pyo3::callback::PyObjectCallbackConverter, py, result) + } + pyo3::class::methods::protocols::PyProcotolMethodWrapped::Add(wrap) + }} +} + pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> syn::Result { // get method names in impl block let mut methods = Vec::new(); + let mut protocol_methods = Vec::new(); for iimpl in impls.iter_mut() { if let syn::ImplItem::Method(ref mut meth) = iimpl { let name = meth.sig.ident.clone(); + + if name.to_string().starts_with("__") && name.to_string().ends_with("__") { + #[allow(clippy::single_match)] + { + match name.to_string().as_str() { + "__add__" => { + protocol_methods.push(binary_func_protocol_wrap(&ty, &name)); + } + _ => { + // This currently breaks the tests + /* + return Err(syn::Error::new_spanned( + meth.sig.ident.clone(), + "Unknown dunder method", + )) + */ + } + } + } + } + methods.push(pymethod::gen_py_method( ty, &name, @@ -36,11 +84,18 @@ pub fn impl_methods(ty: &syn::Type, impls: &mut Vec) -> syn::Resu } Ok(quote! { - pyo3::inventory::submit! { + pyo3::inventory::submit! { #![crate = pyo3] { type TyInventory = <#ty as pyo3::class::methods::PyMethodsInventoryDispatch>::InventoryType; ::new(&[#(#methods),*]) } } + + pyo3::inventory::submit! { + #![crate = pyo3] { + type ProtocolInventory = <#ty as pyo3::class::methods::protocols::PyProtocolInventoryDispatch>::ProtocolInventoryType; + ::new(&[#(#protocol_methods),*]) + } + } }) } diff --git a/src/class/methods.rs b/src/class/methods.rs index 362f78bcfd5..f3c3a9f41ca 100644 --- a/src/class/methods.rs +++ b/src/class/methods.rs @@ -117,7 +117,7 @@ impl PySetterDef { } #[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code -/// This trait is implemented for all pyclass so to implement the [PyMethodsProtocol] +/// This trait is implemented for all pyclass to implement the [PyMethodsProtocol] /// through inventory pub trait PyMethodsInventoryDispatch { /// This allows us to get the inventory type when only the pyclass is in scope @@ -153,3 +153,75 @@ where .collect() } } + +/// Utils to define and collect dunder methods, powered by inventory +pub mod protocols { + use crate::ffi; + + #[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code + /// The c wrapper around a dunder method defined in an impl block + pub enum PyProcotolMethodWrapped { + Add(ffi::binaryfunc), + } + + #[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code + /// All defined dunder methods collected into a single struct + #[derive(Default)] + pub struct PyProcolTypes { + pub(crate) add: Option, + } + + impl PyProcolTypes { + /// Returns whether any dunder method has been defined + pub fn any_defined(&self) -> bool { + self.add.is_some() + } + } + + #[doc(hidden)] // Only to be used through the proc macros, use PyMethodsProtocol in custom code + /// This trait is implemented for all pyclass to implement the [PyProtocolInventory] + /// through inventory + pub trait PyProtocolInventoryDispatch { + /// This allows us to get the inventory type when only the pyclass is in scope + type ProtocolInventoryType: PyProtocolInventory; + } + + #[doc(hidden)] + /// Allows arbitrary pymethod blocks to submit dunder methods, which are eventually collected + /// into [PyProcolTypes] + pub trait PyProtocolInventory: inventory::Collect { + fn new(methods: &'static [PyProcotolMethodWrapped]) -> Self; + fn get_methods(&self) -> &'static [PyProcotolMethodWrapped]; + } + + /// Defines which protocols this class implements + pub trait PyProtocol { + /// Returns all methods that are defined for a class + fn py_protocols() -> PyProcolTypes; + } + + impl PyProtocol for T + where + T: PyProtocolInventoryDispatch, + { + /// Collects all defined dunder methods into a single [PyProcolTypes] instance + fn py_protocols() -> PyProcolTypes { + let mut py_protocol_types = PyProcolTypes::default(); + let flattened = inventory::iter:: + .into_iter() + .flat_map(PyProtocolInventory::get_methods); + for method in flattened { + match method { + PyProcotolMethodWrapped::Add(add) => { + if py_protocol_types.add.is_some() { + panic!("You can't define `__add__` more than once"); + } + py_protocol_types.add = Some(*add); + } + } + } + + py_protocol_types + } + } +} diff --git a/src/class/number.rs b/src/class/number.rs index d95cb61ff57..9839349358f 100644 --- a/src/class/number.rs +++ b/src/class/number.rs @@ -5,6 +5,7 @@ use crate::callback::PyObjectCallbackConverter; use crate::class::basic::PyObjectProtocolImpl; +use crate::class::methods::protocols::PyProtocol; use crate::class::methods::PyMethodDef; use crate::err::PyResult; use crate::ffi; @@ -621,7 +622,7 @@ pub trait PyNumberIndexProtocol<'p>: PyNumberProtocol<'p> { } #[doc(hidden)] -pub trait PyNumberProtocolImpl: PyObjectProtocolImpl { +pub trait PyNumberProtocolImpl: PyObjectProtocolImpl + PyProtocol { fn methods() -> Vec { Vec::new() } @@ -629,6 +630,13 @@ pub trait PyNumberProtocolImpl: PyObjectProtocolImpl { if let Some(nb_bool) = ::nb_bool_fn() { let meth = ffi::PyNumberMethods { nb_bool: Some(nb_bool), + nb_add: Self::py_protocols().add, + ..ffi::PyNumberMethods_INIT + }; + Some(meth) + } else if Self::py_protocols().any_defined() { + let meth = ffi::PyNumberMethods { + nb_add: Self::py_protocols().add, ..ffi::PyNumberMethods_INIT }; Some(meth) @@ -638,11 +646,11 @@ pub trait PyNumberProtocolImpl: PyObjectProtocolImpl { } } -impl<'p, T> PyNumberProtocolImpl for T {} +impl<'p, T> PyNumberProtocolImpl for T where T: PyProtocol {} impl<'p, T> PyNumberProtocolImpl for T where - T: PyNumberProtocol<'p>, + T: PyNumberProtocol<'p> + PyProtocol, { fn tp_as_number() -> Option { Some(ffi::PyNumberMethods { @@ -742,17 +750,17 @@ where } } -trait PyNumberAddProtocolImpl { +trait PyNumberAddProtocolImpl: PyProtocol { fn nb_add() -> Option { - None + Self::py_protocols().add } } -impl<'p, T> PyNumberAddProtocolImpl for T where T: PyNumberProtocol<'p> {} +impl<'p, T> PyNumberAddProtocolImpl for T where T: PyNumberProtocol<'p> + PyProtocol {} impl PyNumberAddProtocolImpl for T where - T: for<'p> PyNumberAddProtocol<'p>, + T: for<'p> PyNumberAddProtocol<'p> + PyProtocol, { fn nb_add() -> Option { py_binary_num_func!( diff --git a/src/type_object.rs b/src/type_object.rs index 878558a7e03..f4d324bffcb 100644 --- a/src/type_object.rs +++ b/src/type_object.rs @@ -2,6 +2,7 @@ //! Python type object information +use crate::class::methods::protocols::PyProtocol; use crate::class::methods::PyMethodDefType; use crate::err::{PyErr, PyResult}; use crate::instance::{Py, PyNativeType}; @@ -249,7 +250,7 @@ pub unsafe trait PyTypeObject { unsafe impl PyTypeObject for T where - T: PyTypeInfo + PyMethodsProtocol + PyObjectAlloc, + T: PyTypeInfo + PyMethodsProtocol + PyObjectAlloc + PyProtocol, { fn init_type() -> NonNull { let type_object = unsafe { ::type_object() }; @@ -297,7 +298,7 @@ impl PyTypeCreate for T where T: PyObjectAlloc + PyTypeObject + Sized {} #[cfg(not(Py_LIMITED_API))] pub fn initialize_type(py: Python, module_name: Option<&str>) -> PyResult<*mut ffi::PyTypeObject> where - T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol, + T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol + PyProtocol, { let type_object: &mut ffi::PyTypeObject = unsafe { T::type_object() }; let base_type_object: &mut ffi::PyTypeObject = @@ -438,7 +439,7 @@ fn py_class_flags(type_object: &mut ffi::PyTypeObject) { } } -fn py_class_method_defs() -> ( +fn py_class_method_defs() -> ( Option, Option, Option, From 566a70ce3f41408c666225c7e8b3138eb8661ac3 Mon Sep 17 00:00:00 2001 From: konstin Date: Sun, 21 Jul 2019 14:08:26 +0200 Subject: [PATCH 2/2] Remove an unsafe --- src/ffi3/object.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ffi3/object.rs b/src/ffi3/object.rs index a20e4c77fdb..1fd32c19153 100644 --- a/src/ffi3/object.rs +++ b/src/ffi3/object.rs @@ -299,7 +299,7 @@ mod typeobject { impl Default for PyNumberMethods { #[inline] fn default() -> Self { - unsafe { mem::zeroed() } + PyNumberMethods_INIT } } macro_rules! as_expr {