diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index eccdcbad9520..3a68c567eef6 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -40,3 +40,4 @@ ) from . import executor +from . import disco diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index b166bd82e9e5..c54f646e17ce 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -21,7 +21,7 @@ import numpy as np -from ..._ffi import register_object +from ..._ffi import register_object, register_func from ..._ffi.runtime_ctypes import Device from ..container import ShapeTuple from ..ndarray import NDArray @@ -153,6 +153,23 @@ def get_global_func(self, name: str) -> DRef: """ return DPackedFunc(_ffi_api.SessionGetGlobalFunc(self, name), self) # type: ignore # pylint: disable=no-member + def import_python_module(self, module_name: str) -> None: + """Import a python module in each worker + + This may be required before call + + Parameters + ---------- + module_name: str + + The python module name, as it would be used in a python + `import` statement. + """ + if not hasattr(self, "_import_python_module"): + self._import_python_module = self.get_global_func("runtime.disco._import_python_module") + + self._import_python_module(module_name) + def call_packed(self, func: DRef, *args) -> DRef: """Call a PackedFunc on workers providing variadic arguments. @@ -369,6 +386,11 @@ def __init__(self, num_workers: int, entrypoint: str) -> None: ) +@register_func("runtime.disco._import_python_module") +def _import_python_module(module_name: str) -> None: + __import__(module_name) + + REDUCE_OPS = { "sum": 0, "prod": 1,