Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 199 additions & 34 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::class::PyTraverseError;
use pyo3::class::PyVisit;
use pyo3::prelude::*;
use pyo3::{py_run, AsPyPointer, PyCell, PyTryInto};
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand Down Expand Up @@ -248,22 +249,10 @@ impl TraversableClass {
}
}

unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
}

#[test]
fn gc_during_borrow() {
Python::with_gil(|py| {
unsafe {
// declare a dummy visitor function
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TraversableClass>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();
Expand All @@ -290,18 +279,18 @@ fn gc_during_borrow() {
}

#[pyclass]
struct PanickyTraverse {
struct PartialTraverse {
member: PyObject,
}

impl PanickyTraverse {
impl PartialTraverse {
fn new(py: Python<'_>) -> Self {
Self { member: py.None() }
}
}

#[pymethods]
impl PanickyTraverse {
impl PartialTraverse {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.member)?;
// In the test, we expect this to never be hit
Expand All @@ -310,29 +299,53 @@ impl PanickyTraverse {
}

#[test]
fn traverse_error() {
fn traverse_partial() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn visit_error(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
-1
}

// get the traverse function
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
let ty = py.get_type::<PartialTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing errors
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
let obj = Py::new(py, PartialTraverse::new(py)).unwrap();
assert_eq!(
traverse(obj.as_ptr(), visit_error, std::ptr::null_mut()),
-1
);
})
}

#[pyclass]
struct PanickyTraverse {
member: PyObject,
}

impl PanickyTraverse {
fn new(py: Python<'_>) -> Self {
Self { member: py.None() }
}
}

#[pymethods]
impl PanickyTraverse {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
visit.call(&self.member)?;
panic!("at the disco");
}
}

#[test]
fn traverse_panic() {
Python::with_gil(|py| unsafe {
// get the traverse function
let ty = py.get_type::<PanickyTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

// confirm that traversing errors
let obj = Py::new(py, PanickyTraverse::new(py)).unwrap();
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}

#[pyclass]
struct TriesGILInTraverse {}

Expand All @@ -346,14 +359,6 @@ impl TriesGILInTraverse {
#[test]
fn tries_gil_in_traverse() {
Python::with_gil(|py| unsafe {
// declare a visitor function which errors (returns nonzero code)
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// get the traverse function
let ty = py.get_type::<TriesGILInTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();
Expand All @@ -363,3 +368,163 @@ fn tries_gil_in_traverse() {
assert_eq!(traverse(obj.as_ptr(), novisit, std::ptr::null_mut()), -1);
})
}

#[pyclass]
struct HijackedTraverse {
traversed: Cell<bool>,
hijacked: Cell<bool>,
}

impl HijackedTraverse {
fn new() -> Self {
Self {
traversed: Cell::new(false),
hijacked: Cell::new(false),
}
}

fn traversed_and_hijacked(&self) -> (bool, bool) {
(self.traversed.get(), self.hijacked.get())
}
}

#[pymethods]
impl HijackedTraverse {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.traversed.set(true);
Ok(())
}
}

trait Traversable {
fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError>;
}

impl<'a> Traversable for PyRef<'a, HijackedTraverse> {
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.hijacked.set(true);
Ok(())
}
}

#[test]
fn traverse_cannot_be_hijacked() {
Python::with_gil(|py| unsafe {
// get the traverse function
let ty = py.get_type::<HijackedTraverse>().as_type_ptr();
let traverse = get_type_traverse(ty).unwrap();

let cell = PyCell::new(py, HijackedTraverse::new()).unwrap();
let obj = cell.to_object(py);
assert_eq!(cell.borrow().traversed_and_hijacked(), (false, false));
traverse(obj.as_ptr(), novisit, std::ptr::null_mut());
assert_eq!(cell.borrow().traversed_and_hijacked(), (true, false));
})
}

#[allow(dead_code)]
#[pyclass]
struct DropDuringTraversal {
cycle: Cell<Option<Py<Self>>>,
dropped: TestDropCall,
}

#[pymethods]
impl DropDuringTraversal {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.cycle.take();
Ok(())
}

fn __clear__(&mut self) {
self.cycle.take();
}
}

#[test]
fn drop_during_traversal_with_gil() {
let drop_called = Arc::new(AtomicBool::new(false));

Python::with_gil(|py| {
let inst = Py::new(
py,
DropDuringTraversal {
cycle: Cell::new(None),
dropped: TestDropCall {
drop_called: Arc::clone(&drop_called),
},
},
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));

drop(inst);
});

// due to the internal GC mechanism, we may need multiple
// (but not too many) collections to get `inst` actually dropped.
for _ in 0..10 {
Python::with_gil(|py| {
py.run("import gc; gc.collect()", None, None).unwrap();
});
}
assert!(drop_called.load(Ordering::Relaxed));
}

#[test]
fn drop_during_traversal_without_gil() {
let drop_called = Arc::new(AtomicBool::new(false));

let inst = Python::with_gil(|py| {
let inst = Py::new(
py,
DropDuringTraversal {
cycle: Cell::new(None),
dropped: TestDropCall {
drop_called: Arc::clone(&drop_called),
},
},
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));

inst
});

drop(inst);

// due to the internal GC mechanism, we may need multiple
// (but not too many) collections to get `inst` actually dropped.
for _ in 0..10 {
Python::with_gil(|py| {
py.run("import gc; gc.collect()", None, None).unwrap();
});
}
assert!(drop_called.load(Ordering::Relaxed));
}

// Manual traversal utilities

unsafe fn get_type_traverse(tp: *mut pyo3::ffi::PyTypeObject) -> Option<pyo3::ffi::traverseproc> {
std::mem::transmute(pyo3::ffi::PyType_GetSlot(tp, pyo3::ffi::Py_tp_traverse))
}

// a dummy visitor function
extern "C" fn novisit(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
0
}

// a visitor function which errors (returns nonzero code)
extern "C" fn visit_error(
_object: *mut pyo3::ffi::PyObject,
_arg: *mut core::ffi::c_void,
) -> std::os::raw::c_int {
-1
}