diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index b8f74bacb00d..ee151db7166c 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -142,6 +142,10 @@ def empty( func = self._get_cached_method("runtime.disco.empty") return func(ShapeTuple(shape), dtype, device) + def shutdown(self): + """Shut down the Disco session""" + _ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member + def get_global_func(self, name: str) -> DRef: """Get a global function on workers. diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index 12339c4fa58c..e74d3819fe04 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -52,6 +52,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.SessionCallPacked").set_body([](TVMArgs args, *rv = SessionObj::FFI::CallWithPacked( self, TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1)); }); +TVM_REGISTER_GLOBAL("runtime.disco.SessionShutdown") + .set_body_method(&SessionObj::Shutdown); } // namespace runtime } // namespace tvm