diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 51df95b4097..eca9f1f2b4b 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -407,6 +407,10 @@ impl ClassWithGCSupport { } ``` +Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`. +Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`, +i.e. `Python::with_gil` will panic. + > Note: these methods are part of the C API, PyPy does not necessarily honor them. If you are building for PyPy you should measure memory consumption to make sure you do not have runaway memory growth. See [this issue on the PyPy bug tracker](https://foss.heptapod.net/pypy/pypy/-/issues/3899). [`IterNextOutput`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.IterNextOutput.html diff --git a/newsfragments/3168.changed.md b/newsfragments/3168.changed.md new file mode 100644 index 00000000000..a00f7eded30 --- /dev/null +++ b/newsfragments/3168.changed.md @@ -0,0 +1 @@ +Safe access to the GIL, for example via `Python::with_gil`, is now locked inside of implementations of the `__traverse__` slot. diff --git a/newsfragments/3168.fixed.md b/newsfragments/3168.fixed.md new file mode 100644 index 00000000000..395b1acb424 --- /dev/null +++ b/newsfragments/3168.fixed.md @@ -0,0 +1 @@ +Do not apply deferred reference count updates when entering a `__traverse__` implementation is it cannot alter any reference counts while the garbage collector is running. diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index bca3dab747c..9689d863a44 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -406,22 +406,8 @@ fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndS slf: *mut _pyo3::ffi::PyObject, visit: _pyo3::ffi::visitproc, arg: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int - { - let trap = _pyo3::impl_::panic::PanicTrap::new("uncaught panic inside __traverse__ handler"); - let pool = _pyo3::GILPool::new(); - let py = pool.python(); - let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf); - - let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py); - let borrow = slf.try_borrow(); - let retval = if let ::std::result::Result::Ok(borrow) = borrow { - _pyo3::impl_::pymethods::unwrap_traverse_result(borrow.#rust_fn_ident(visit)) - } else { - 0 - }; - trap.disarm(); - retval + ) -> ::std::os::raw::c_int { + _pyo3::impl_::pymethods::call_traverse_impl::<#cls>(slf, #cls::#rust_fn_ident, visit, arg) } }; let slot_def = quote! { diff --git a/src/gil.rs b/src/gil.rs index 903f89d70d6..fa96a910797 100644 --- a/src/gil.rs +++ b/src/gil.rs @@ -29,12 +29,17 @@ thread_local_const_init! { /// they are dropped. /// /// As a result, if this thread has the GIL, GIL_COUNT is greater than zero. - static GIL_COUNT: Cell = const { Cell::new(0) }; + /// + /// Additionally, we sometimes need to prevent safe access to the GIL, + /// e.g. when implementing `__traverse__`, which is represented by a negative value. + static GIL_COUNT: Cell = const { Cell::new(0) }; /// Temporarily hold objects that will be released when the GILPool drops. static OWNED_OBJECTS: RefCell>> = const { RefCell::new(Vec::new()) }; } +const GIL_LOCKED_DURING_TRAVERSE: isize = -1; + /// Checks whether the GIL is acquired. /// /// Note: This uses pyo3's internal count rather than PyGILState_Check for two reasons: @@ -286,7 +291,7 @@ static POOL: ReferencePool = ReferencePool::new(); /// A guard which can be used to temporarily release the GIL and restore on `Drop`. pub(crate) struct SuspendGIL { - count: usize, + count: isize, tstate: *mut ffi::PyThreadState, } @@ -311,6 +316,40 @@ impl Drop for SuspendGIL { } } +/// Used to lock safe access to the GIL +pub(crate) struct LockGIL { + count: isize, +} + +impl LockGIL { + /// Lock access to the GIL while an implementation of `__traverse__` is running + pub fn during_traverse() -> Self { + Self::new(GIL_LOCKED_DURING_TRAVERSE) + } + + fn new(reason: isize) -> Self { + let count = GIL_COUNT.with(|c| c.replace(reason)); + + Self { count } + } + + #[cold] + fn bail(current: isize) { + match current { + GIL_LOCKED_DURING_TRAVERSE => panic!( + "Access to the GIL is prohibited while a __traverse__ implmentation is running." + ), + _ => panic!("Access to the GIL is currently prohibited."), + } + } +} + +impl Drop for LockGIL { + fn drop(&mut self) { + GIL_COUNT.with(|c| c.set(self.count)); + } +} + /// A RAII pool which PyO3 uses to store owned Python references. /// /// See the [Memory Management] chapter of the guide for more information about how PyO3 uses @@ -421,7 +460,13 @@ pub unsafe fn register_owned(_py: Python<'_>, obj: NonNull) { #[inline(always)] fn increment_gil_count() { // Ignores the error in case this function called from `atexit`. - let _ = GIL_COUNT.try_with(|c| c.set(c.get() + 1)); + let _ = GIL_COUNT.try_with(|c| { + let current = c.get(); + if current < 0 { + LockGIL::bail(current); + } + c.set(current + 1); + }); } /// Decrements pyo3's internal GIL count - to be called whenever GILPool or GILGuard is dropped. diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index fde97430001..60db3fbb1bf 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -1,9 +1,15 @@ +use crate::gil::LockGIL; +use crate::impl_::panic::PanicTrap; use crate::internal_tricks::extract_c_string; -use crate::{ffi, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, PyTraverseError, Python}; +use crate::{ + ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, + Python, +}; use std::borrow::Cow; use std::ffi::CStr; use std::fmt; -use std::os::raw::c_int; +use std::os::raw::{c_int, c_void}; +use std::panic::{catch_unwind, AssertUnwindSafe}; /// Python 3.8 and up - __ipow__ has modulo argument correctly populated. #[cfg(Py_3_8)] @@ -239,14 +245,46 @@ impl PySetterDef { } } -/// Unwraps the result of __traverse__ for tp_traverse +/// Calls an implementation of __traverse__ for tp_traverse #[doc(hidden)] -#[inline] -pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int { - match result { - Ok(()) => 0, - Err(PyTraverseError(value)) => value, - } +pub unsafe fn call_traverse_impl( + slf: *mut ffi::PyObject, + impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>, + visit: ffi::visitproc, + arg: *mut c_void, +) -> c_int +where + T: PyClass, +{ + // It is important the implementation of `__traverse__` cannot safely access the GIL, + // c.f. https://github.com/PyO3/pyo3/issues/3165, and hence we do not expose our GIL + // token to the user code and lock safe methods for acquiring the GIL. + // (This includes enforcing the `&self` method receiver as e.g. `PyRef` could + // reconstruct a GIL token via `PyRef::py`.) + // Since we do not create a `GILPool` at all, it is important that our usage of the GIL + // token does not produce any owned objects thereby calling into `register_owned`. + let trap = PanicTrap::new("uncaught panic inside __traverse__ handler"); + + let py = Python::assume_gil_acquired(); + let slf = py.from_borrowed_ptr::>(slf); + let borrow = slf.try_borrow(); + let visit = PyVisit::from_raw(visit, arg, py); + + let retval = if let Ok(borrow) = borrow { + let _lock = LockGIL::during_traverse(); + + match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) { + Ok(res) => match res { + Ok(()) => 0, + Err(PyTraverseError(value)) => value, + }, + Err(_err) => -1, + } + } else { + 0 + }; + trap.disarm(); + retval } pub(crate) struct PyMethodDefDestructor { diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index ffef87ee2fe..9a28f4e564f 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -118,6 +118,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/not_send2.rs"); t.compile_fail("tests/ui/not_send3.rs"); t.compile_fail("tests/ui/get_set_all.rs"); + t.compile_fail("tests/ui/traverse_bare_self.rs"); } #[rustversion::before(1.63)] diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 8d612eb9c60..e8cb65168ca 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -332,3 +332,34 @@ fn traverse_error() { ); }) } + +#[pyclass] +struct TriesGILInTraverse {} + +#[pymethods] +impl TriesGILInTraverse { + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + Python::with_gil(|_py| Ok(())) + } +} + +#[test] +fn tries_gil_in_traverse() { + Python::with_gil(|py| unsafe { + // declare a visitor function which errors (returns nonzero code) + extern "C" fn novisit( + _object: *mut pyo3::ffi::PyObject, + _arg: *mut core::ffi::c_void, + ) -> std::os::raw::c_int { + 0 + } + + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + // confirm that traversing panicks + let obj = Py::new(py, TriesGILInTraverse {}).unwrap(); + assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1); + }) +} diff --git a/tests/ui/traverse_bare_self.rs b/tests/ui/traverse_bare_self.rs new file mode 100644 index 00000000000..5adc316e43f --- /dev/null +++ b/tests/ui/traverse_bare_self.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; +use pyo3::PyVisit; + +#[pyclass] +struct TraverseTriesToTakePyRef {} + +#[pymethods] +impl TraverseTriesToTakePyRef { + fn __traverse__(slf: PyRef, visit: PyVisit) {} +} + +fn main() {} diff --git a/tests/ui/traverse_bare_self.stderr b/tests/ui/traverse_bare_self.stderr new file mode 100644 index 00000000000..aba76145dc3 --- /dev/null +++ b/tests/ui/traverse_bare_self.stderr @@ -0,0 +1,17 @@ +error[E0308]: mismatched types + --> tests/ui/traverse_bare_self.rs:8:6 + | +7 | #[pymethods] + | ------------ arguments to this function are incorrect +8 | impl TraverseTriesToTakePyRef { + | ______^ +9 | | fn __traverse__(slf: PyRef, visit: PyVisit) {} + | |___________________^ expected fn pointer, found fn item + | + = note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>` + found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}` +note: function defined here + --> src/impl_/pymethods.rs + | + | pub unsafe fn call_traverse_impl( + | ^^^^^^^^^^^^^^^^^^