From b63292feffc435adf09512ada3c091c7db0f62d9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Feb 2024 21:49:32 +0000 Subject: [PATCH] [Disco] Implement `Session.import_python_module` method Import a module into the workers. If a python module has not yet been loaded, `Session.get_global_func` cannot load a packed func from it. --- python/tvm/runtime/__init__.py | 1 + python/tvm/runtime/disco/session.py | 24 +++++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index eccdcbad9520..3a68c567eef6 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -40,3 +40,4 @@ ) from . import executor +from . import disco diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index b166bd82e9e5..c54f646e17ce 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -21,7 +21,7 @@ import numpy as np -from ..._ffi import register_object +from ..._ffi import register_object, register_func from ..._ffi.runtime_ctypes import Device from ..container import ShapeTuple from ..ndarray import NDArray @@ -153,6 +153,23 @@ def get_global_func(self, name: str) -> DRef: """ return DPackedFunc(_ffi_api.SessionGetGlobalFunc(self, name), self) # type: ignore # pylint: disable=no-member + def import_python_module(self, module_name: str) -> None: + """Import a python module in each worker + + This may be required before call + + Parameters + ---------- + module_name: str + + The python module name, as it would be used in a python + `import` statement. + """ + if not hasattr(self, "_import_python_module"): + self._import_python_module = self.get_global_func("runtime.disco._import_python_module") + + self._import_python_module(module_name) + def call_packed(self, func: DRef, *args) -> DRef: """Call a PackedFunc on workers providing variadic arguments. @@ -369,6 +386,11 @@ def __init__(self, num_workers: int, entrypoint: str) -> None: ) +@register_func("runtime.disco._import_python_module") +def _import_python_module(module_name: str) -> None: + __import__(module_name) + + REDUCE_OPS = { "sum": 0, "prod": 1,