diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index 51df95b4097..eca9f1f2b4b 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -407,6 +407,10 @@ impl ClassWithGCSupport { } ``` +Usually, an implementation of `__traverse__` should do nothing but calls to `visit.call`. +Most importantly, safe access to the GIL is prohibited inside implementations of `__traverse__`, +i.e. `Python::with_gil` will panic. + > Note: these methods are part of the C API, PyPy does not necessarily honor them. If you are building for PyPy you should measure memory consumption to make sure you do not have runaway memory growth. See [this issue on the PyPy bug tracker](https://foss.heptapod.net/pypy/pypy/-/issues/3899). [`IterNextOutput`]: {{#PYO3_DOCS_URL}}/pyo3/pyclass/enum.IterNextOutput.html diff --git a/newsfragments/3168.changed.md b/newsfragments/3168.changed.md new file mode 100644 index 00000000000..a00f7eded30 --- /dev/null +++ b/newsfragments/3168.changed.md @@ -0,0 +1 @@ +Safe access to the GIL, for example via `Python::with_gil`, is now locked inside of implementations of the `__traverse__` slot. diff --git a/newsfragments/3168.fixed.md b/newsfragments/3168.fixed.md new file mode 100644 index 00000000000..395b1acb424 --- /dev/null +++ b/newsfragments/3168.fixed.md @@ -0,0 +1 @@ +Do not apply deferred reference count updates when entering a `__traverse__` implementation is it cannot alter any reference counts while the garbage collector is running. diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index bca3dab747c..9689d863a44 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -406,22 +406,8 @@ fn impl_traverse_slot(cls: &syn::Type, rust_fn_ident: &syn::Ident) -> MethodAndS slf: *mut _pyo3::ffi::PyObject, visit: _pyo3::ffi::visitproc, arg: *mut ::std::os::raw::c_void, - ) -> ::std::os::raw::c_int - { - let trap = _pyo3::impl_::panic::PanicTrap::new("uncaught panic inside __traverse__ handler"); - let pool = _pyo3::GILPool::new(); - let py = pool.python(); - let slf = py.from_borrowed_ptr::<_pyo3::PyCell<#cls>>(slf); - - let visit = _pyo3::class::gc::PyVisit::from_raw(visit, arg, py); - let borrow = slf.try_borrow(); - let retval = if let ::std::result::Result::Ok(borrow) = borrow { - _pyo3::impl_::pymethods::unwrap_traverse_result(borrow.#rust_fn_ident(visit)) - } else { - 0 - }; - trap.disarm(); - retval + ) -> ::std::os::raw::c_int { + _pyo3::impl_::pymethods::call_traverse_impl::<#cls>(slf, #cls::#rust_fn_ident, visit, arg) } }; let slot_def = quote! { diff --git a/src/gil.rs b/src/gil.rs index aa4ed2e7825..746d8bc885f 100644 --- a/src/gil.rs +++ b/src/gil.rs @@ -17,7 +17,10 @@ thread_local! { /// they are dropped. /// /// As a result, if this thread has the GIL, GIL_COUNT is greater than zero. - static GIL_COUNT: Cell = Cell::new(0); + /// + /// Additionally, we sometimes need to prevent safe access to the GIL, + /// e.g. when implementing `__traverse__`, which is represented by a negative value. + static GIL_COUNT: Cell = Cell::new(0); /// Temporarily hold objects that will be released when the GILPool drops. static OWNED_OBJECTS: RefCell>> = RefCell::new(Vec::with_capacity(256)); @@ -290,7 +293,7 @@ static POOL: ReferencePool = ReferencePool::new(); /// A guard which can be used to temporarily release the GIL and restore on `Drop`. pub(crate) struct SuspendGIL { - count: usize, + count: isize, tstate: *mut ffi::PyThreadState, } @@ -315,6 +318,27 @@ impl Drop for SuspendGIL { } } +/// Used to lock safe access to the GIL, used to implement `__traverse__` +pub struct LockGIL { + count: isize, +} + +impl LockGIL { + /// TODO + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let count = GIL_COUNT.with(|c| c.replace(-1)); + + Self { count } + } +} + +impl Drop for LockGIL { + fn drop(&mut self) { + GIL_COUNT.with(|c| c.set(self.count)); + } +} + /// A RAII pool which PyO3 uses to store owned Python references. /// /// See the [Memory Management] chapter of the guide for more information about how PyO3 uses @@ -425,7 +449,14 @@ pub unsafe fn register_owned(_py: Python<'_>, obj: NonNull) { #[inline(always)] fn increment_gil_count() { // Ignores the error in case this function called from `atexit`. - let _ = GIL_COUNT.try_with(|c| c.set(c.get() + 1)); + let _ = GIL_COUNT.try_with(|c| { + let current = c.get(); + assert!( + current >= 0, + "Access to the GIL is currently prohibited, for example because a `__traverse__` implementation is currently running." + ); + c.set(current + 1); + }); } /// Decrements pyo3's internal GIL count - to be called whenever GILPool or GILGuard is dropped. diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index fde97430001..ad9a20ac01a 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -1,9 +1,15 @@ +use crate::gil::LockGIL; +use crate::impl_::panic::PanicTrap; use crate::internal_tricks::extract_c_string; -use crate::{ffi, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, PyTraverseError, Python}; +use crate::{ + ffi, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, + Python, +}; use std::borrow::Cow; use std::ffi::CStr; use std::fmt; -use std::os::raw::c_int; +use std::os::raw::{c_int, c_void}; +use std::panic::{catch_unwind, AssertUnwindSafe}; /// Python 3.8 and up - __ipow__ has modulo argument correctly populated. #[cfg(Py_3_8)] @@ -239,14 +245,46 @@ impl PySetterDef { } } -/// Unwraps the result of __traverse__ for tp_traverse +/// Calls an implementation of __traverse__ for tp_traverse #[doc(hidden)] -#[inline] -pub fn unwrap_traverse_result(result: Result<(), PyTraverseError>) -> c_int { - match result { - Ok(()) => 0, - Err(PyTraverseError(value)) => value, - } +pub unsafe fn call_traverse_impl( + slf: *mut ffi::PyObject, + impl_: fn(&T, PyVisit<'_>) -> Result<(), PyTraverseError>, + visit: ffi::visitproc, + arg: *mut c_void, +) -> c_int +where + T: PyClass, +{ + // It is important the implementation of `__traverse__` cannot safely access the GIL, + // c.f. https://github.com/PyO3/pyo3/issues/3165, and hence we do not expose our GIL + // token to the user code and lock safe methods for acquiring the GIL. + // (This includes enforcing the `&self` method receiver as e.g. `PyRef` could + // reconstruct a GIL token via `PyRef::py`.) + // Since we do not create a `GILPool` at all, it is important that our usage of the GIL + // token does not produce any owned objects thereby calling into `register_owned`. + let trap = PanicTrap::new("uncaught panic inside __traverse__ handler"); + + let py = Python::assume_gil_acquired(); + let slf = py.from_borrowed_ptr::>(slf); + let borrow = slf.try_borrow(); + let visit = PyVisit::from_raw(visit, arg, py); + + let retval = if let Ok(borrow) = borrow { + let _lock = LockGIL::new(); + + match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) { + Ok(res) => match res { + Ok(()) => 0, + Err(PyTraverseError(value)) => value, + }, + Err(_err) => -1, + } + } else { + 0 + }; + trap.disarm(); + retval } pub(crate) struct PyMethodDefDestructor { diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index ffef87ee2fe..88d3ce2a1b9 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -45,6 +45,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/reject_generics.rs"); + t.compile_fail("tests/ui/traverse_bare_self.rs"); tests_rust_1_49(&t); tests_rust_1_56(&t); diff --git a/tests/test_gc.rs b/tests/test_gc.rs index 8d612eb9c60..c84d6784633 100644 --- a/tests/test_gc.rs +++ b/tests/test_gc.rs @@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError; use pyo3::class::PyVisit; use pyo3::prelude::*; use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto}; +use std::cell::Cell; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -248,22 +249,10 @@ impl TraversableClass { } } -unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { - std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse)) -} - #[test] fn gc_during_borrow() { Python::with_gil(|py| { unsafe { - // declare a dummy visitor function - extern "C" fn novisit( - _object: *mut pyo3::ffi::PyObject, - _arg: *mut core::ffi::c_void, - ) -> std::os::raw::c_int { - 0 - } - // get the traverse function let ty = py.get_type::().as_type_ptr(); let traverse = get_type_traverse(ty).unwrap(); @@ -290,18 +279,18 @@ fn gc_during_borrow() { } #[pyclass] -struct PanickyTraverse { +struct PartialTraverse { member: PyObject, } -impl PanickyTraverse { +impl PartialTraverse { fn new(py: Python<'_>) -> Self { Self { member: py.None() } } } #[pymethods] -impl PanickyTraverse { +impl PartialTraverse { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { visit.call(&self.member)?; // In the test, we expect this to never be hit @@ -310,25 +299,232 @@ impl PanickyTraverse { } #[test] -fn traverse_error() { +fn traverse_partial() { Python::with_gil(|py| unsafe { - // declare a visitor function which errors (returns nonzero code) - extern "C" fn visit_error( - _object: *mut pyo3::ffi::PyObject, - _arg: *mut core::ffi::c_void, - ) -> std::os::raw::c_int { - -1 - } - // get the traverse function - let ty = py.get_type::().as_type_ptr(); + let ty = py.get_type::().as_type_ptr(); let traverse = get_type_traverse(ty).unwrap(); // confirm that traversing errors - let obj = Py::new(py, PanickyTraverse::new(py)).unwrap(); + let obj = Py::new(py, PartialTraverse::new(py)).unwrap(); assert_eq!( traverse(obj.as_ptr(), visit_error, std::ptr::null_mut()), -1 ); }) } + +#[pyclass] +struct PanickyTraverse { + member: PyObject, +} + +impl PanickyTraverse { + fn new(py: Python<'_>) -> Self { + Self { member: py.None() } + } +} + +#[pymethods] +impl PanickyTraverse { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.member)?; + panic!("at the disco"); + } +} + +#[test] +fn traverse_panic() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + // confirm that traversing errors + let obj = Py::new(py, PanickyTraverse::new(py)).unwrap(); + assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1); + }) +} + +#[pyclass] +struct TriesGILInTraverse {} + +#[pymethods] +impl TriesGILInTraverse { + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + Python::with_gil(|_py| Ok(())) + } +} + +#[test] +fn tries_gil_in_traverse() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + // confirm that traversing panicks + let obj = Py::new(py, TriesGILInTraverse {}).unwrap(); + assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1); + }) +} + +#[pyclass] +struct HijackedTraverse { + traversed: Cell, + hijacked: Cell, +} + +impl HijackedTraverse { + fn new() -> Self { + Self { + traversed: Cell::new(false), + hijacked: Cell::new(false), + } + } + + fn traversed_and_hijacked(&self) -> (bool, bool) { + (self.traversed.get(), self.hijacked.get()) + } +} + +#[pymethods] +impl HijackedTraverse { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.traversed.set(true); + Ok(()) + } +} + +trait Traversable { + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>; +} + +impl<'a> Traversable for PyRef<'a, HijackedTraverse> { + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.hijacked.set(true); + Ok(()) + } +} + +#[test] +fn traverse_cannot_be_hijacked() { + Python::with_gil(|py| unsafe { + // get the traverse function + let ty = py.get_type::().as_type_ptr(); + let traverse = get_type_traverse(ty).unwrap(); + + let cell = PyCell::new(py, HijackedTraverse::new()).unwrap(); + let obj = cell.to_object(py); + assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false)); + traverse(obj.as_ptr(), novisit, std::ptr::null_mut()); + assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false)); + }) +} + +#[allow(dead_code)] +#[pyclass] +struct DropDuringTraversal { + cycle: Cell>>, + dropped: TestDropCall, +} + +#[pymethods] +impl DropDuringTraversal { + #[allow(clippy::unnecessary_wraps)] + fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + self.cycle.take(); + Ok(()) + } + + fn __clear__(&mut self) { + self.cycle.take(); + } +} + +#[test] +fn drop_during_traversal_with_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + drop(inst); + }); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + +#[test] +fn drop_during_traversal_without_gil() { + let drop_called = Arc::new(AtomicBool::new(false)); + + let inst = Python::with_gil(|py| { + let inst = Py::new( + py, + DropDuringTraversal { + cycle: Cell::new(None), + dropped: TestDropCall { + drop_called: Arc::clone(&drop_called), + }, + }, + ) + .unwrap(); + + inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py))); + + inst + }); + + drop(inst); + + // due to the internal GC mechanism, we may need multiple + // (but not too many) collections to get `inst` actually dropped. + for _ in 0..10 { + Python::with_gil(|py| { + py.run("import gc; gc.collect()", None, None).unwrap(); + }); + } + assert!(drop_called.load(Ordering::Relaxed)); +} + +// Manual traversal utilities + +unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option { + std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse)) +} + +// a dummy visitor function +extern "C" fn novisit( + _object: *mut pyo3::ffi::PyObject, + _arg: *mut core::ffi::c_void, +) -> std::os::raw::c_int { + 0 +} + +// a visitor function which errors (returns nonzero code) +extern "C" fn visit_error( + _object: *mut pyo3::ffi::PyObject, + _arg: *mut core::ffi::c_void, +) -> std::os::raw::c_int { + -1 +} diff --git a/tests/ui/traverse_bare_self.rs b/tests/ui/traverse_bare_self.rs new file mode 100644 index 00000000000..5adc316e43f --- /dev/null +++ b/tests/ui/traverse_bare_self.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; +use pyo3::PyVisit; + +#[pyclass] +struct TraverseTriesToTakePyRef {} + +#[pymethods] +impl TraverseTriesToTakePyRef { + fn __traverse__(slf: PyRef, visit: PyVisit) {} +} + +fn main() {} diff --git a/tests/ui/traverse_bare_self.stderr b/tests/ui/traverse_bare_self.stderr new file mode 100644 index 00000000000..aba76145dc3 --- /dev/null +++ b/tests/ui/traverse_bare_self.stderr @@ -0,0 +1,17 @@ +error[E0308]: mismatched types + --> tests/ui/traverse_bare_self.rs:8:6 + | +7 | #[pymethods] + | ------------ arguments to this function are incorrect +8 | impl TraverseTriesToTakePyRef { + | ______^ +9 | | fn __traverse__(slf: PyRef, visit: PyVisit) {} + | |___________________^ expected fn pointer, found fn item + | + = note: expected fn pointer `for<'a, 'b> fn(&'a TraverseTriesToTakePyRef, PyVisit<'b>) -> Result<(), PyTraverseError>` + found fn item `for<'a, 'b> fn(pyo3::PyRef<'a, TraverseTriesToTakePyRef>, PyVisit<'b>) {TraverseTriesToTakePyRef::__traverse__}` +note: function defined here + --> src/impl_/pymethods.rs + | + | pub unsafe fn call_traverse_impl( + | ^^^^^^^^^^^^^^^^^^