diff --git a/news/856.bugfix.rst b/news/856.bugfix.rst new file mode 100644 index 0000000000..b4ecdadd72 --- /dev/null +++ b/news/856.bugfix.rst @@ -0,0 +1 @@ +Fix an ``AttributeError`` crash on shutdown when tracking a script that calls ``gevent.monkey.patch_all()`` (or otherwise replaces ``threading.Thread`` while tracking). The ``Tracker`` now records the ``threading.Thread`` subclass it patched on entry and removes its instrumentation from that same class on exit, instead of looking up ``threading.Thread`` again. diff --git a/src/memray/_memray.pyx b/src/memray/_memray.pyx index d5702ecf9f..e2663b9293 100644 --- a/src/memray/_memray.pyx +++ b/src/memray/_memray.pyx @@ -755,6 +755,7 @@ cdef class Tracker: cdef bool _trace_python_allocators cdef object _previous_profile_func cdef object _previous_thread_profile_func + cdef object _patched_thread_class cdef unique_ptr[RecordWriter] _writer cdef object _surviving_objects @@ -832,21 +833,22 @@ cdef class Tracker: raise RuntimeError("Attempting to use stale output handle") writer = move(self._writer) + self._patched_thread_class = threading.Thread for attr in ("_name", "_ident"): - assert not hasattr(threading.Thread, attr) + assert not hasattr(self._patched_thread_class, attr) setattr( - threading.Thread, + self._patched_thread_class, attr, ThreadNameInterceptor(attr, NativeTracker.registerThreadNameById), ) - orig_set_os_name = getattr(threading.Thread, "_set_os_name", None) + orig_set_os_name = getattr(self._patched_thread_class, "_set_os_name", None) if orig_set_os_name is not None: def set_os_name_wrapper(self): cdef unique_ptr[RecursionGuard] guard = make_unique[RecursionGuard]() orig_set_os_name(self) - setattr(threading.Thread, "_set_os_name", set_os_name_wrapper) + setattr(self._patched_thread_class, "_set_os_name", set_os_name_wrapper) self._previous_profile_func = sys.getprofile() self._previous_thread_profile_func = threading._profile_hook @@ -875,7 +877,11 @@ cdef class Tracker: threading.setprofile(self._previous_thread_profile_func) for attr in ("_name", "_ident"): - delattr(threading.Thread, attr) + try: + delattr(self._patched_thread_class, attr) + except AttributeError: + pass + self._patched_thread_class = None cdef void _populate_surviving_objects(self): cdef NativeTracker *tracker = NativeTracker.getTracker() diff --git a/tests/unit/test_tracker.py b/tests/unit/test_tracker.py index b769c892bc..13bb956969 100644 --- a/tests/unit/test_tracker.py +++ b/tests/unit/test_tracker.py @@ -35,6 +35,23 @@ def test_the_same_tracker_cannot_be_activated_twice(tmpdir): pass +def test_thread_class_swap_during_tracking_does_not_crash(tmpdir, monkeypatch): + """Regression test for https://github.com/bloomberg/memray/issues/856. + + ``gevent.monkey.patch_all()`` replaces ``threading.Thread`` with its own + class while tracking is active. ``Tracker.__exit__`` must clean up the + instrumentation it installed on the original ``Thread`` class, not the + one that ``threading.Thread`` happens to point at on exit. + """ + output = Path(tmpdir) / "test.bin" + with Tracker(output): + + class FakeThread: + pass + + monkeypatch.setattr("threading.Thread", FakeThread) + + @skip_if_macos def test_rss_from_proc_status_includes_hugetlb_pages(): assert (