diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index 1013d14a89c1..53b362f57983 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -17,6 +17,11 @@ """This module defines a Session in Disco. Session is the primary interface that users interact with the distributed runtime. """ + +import os +import pickle + + from typing import Any, Callable, Optional, Sequence, Union import numpy as np @@ -384,6 +389,42 @@ def __init__(self, num_workers: int, entrypoint: str = "tvm.exec.disco_worker") "runtime.disco.create_process_pool", entrypoint, ) + self._configure_structlog() + + def _configure_structlog(self) -> None: + try: + import structlog # pylint: disable=import-outside-toplevel + except ImportError: + return + + config = pickle.dumps(structlog.get_config()) + func = self.get_global_func("runtime.disco._configure_structlog") + func(config, os.getpid()) + + +@register_func("runtime.disco._configure_structlog") +def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None: + """Configure structlog for all disco workers + + The child processes + + Parameters + ---------- + pickled_config: bytes + + The pickled configuration for structlog + + parent_pid: int + + The PID of the main process. This is used to restrict the + """ + if os.getpid() == parent_pid: + return + + import structlog # pylint: disable=import-outside-toplevel + + config = pickle.loads(pickled_config) + structlog.configure(**config) @register_func("runtime.disco._import_python_module")