From 255948f34a126c8b1cc726c401fa4973de2493e0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 8 May 2024 14:54:15 +0000 Subject: [PATCH] [Disco] Expose disco.Session.shutdown through the python API Prior to this commit, the `SessionObj::Shutdown` method could be called from the C++ API, but could not be called through the Python API. While it is implicitly called when the `SessionObj` is destructed, Python's garbage collection may result in the destruction occurring later than expected. This commit exposes `SessionObj::Shutdown` through the Python API as `disco.Session.shutdown`, allowing it to be closed cleanly. --- python/tvm/runtime/disco/session.py | 4 ++++ src/runtime/disco/session.cc | 2 ++ 2 files changed, 6 insertions(+) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index b8f74bacb00d..ee151db7166c 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -142,6 +142,10 @@ def empty( func = self._get_cached_method("runtime.disco.empty") return func(ShapeTuple(shape), dtype, device) + def shutdown(self): + """Shut down the Disco session""" + _ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member + def get_global_func(self, name: str) -> DRef: """Get a global function on workers. diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index 12339c4fa58c..e74d3819fe04 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -52,6 +52,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs args, *rv = SessionObj::FFI::CallWithPacked( self, TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1)); }); +TVM_REGISTER_GLOBAL("runtime.disco.SessionShutdown") + .set_body_method(&SessionObj::Shutdown); } // namespace runtime } // namespace tvm