diff --git a/pymllm/orchestrator/scheduler_process.py b/pymllm/orchestrator/scheduler_process.py
index 3bc3466a..d085266d 100644
--- a/pymllm/orchestrator/scheduler_process.py
+++ b/pymllm/orchestrator/scheduler_process.py
@@ -29,6 +29,7 @@
 """

 import logging
+import os
 import queue as stdlib_queue
 import time
 from collections import deque
@@ -55,6 +56,27 @@
 _DEFAULT_MAX_TOTAL_TOKENS = 131072
 _DEFAULT_MAX_NEW_TOKENS = 32768

+# Brief poll timeout (ms) used between decode batches to avoid 100% CPU spin.
+# 1 ms is enough to yield the CPU core to the OS scheduler while adding
+# negligible latency (decode steps typically take >1 ms on the GPU anyway).
+# Override via the MLLM_DECODE_POLL_TIMEOUT_MS env var for testing.
+def _read_decode_poll_timeout_ms() -> int:
+    raw = os.environ.get("MLLM_DECODE_POLL_TIMEOUT_MS", "1")
+    try:
+        val = int(raw)
+    except ValueError:
+        raise ValueError(
+            f"MLLM_DECODE_POLL_TIMEOUT_MS must be a non-negative integer, got {raw!r}"
+        ) from None
+    if val < 0:
+        raise ValueError(
+            f"MLLM_DECODE_POLL_TIMEOUT_MS must be >= 0, got {val}"
+        )
+    return val
+
+
+_DECODE_POLL_TIMEOUT_MS = _read_decode_poll_timeout_ms()
+

 # ======================================================================
 # IdleSleeper -- avoid busy-looping when no work is available
@@ -482,20 +504,30 @@ def init_model(self) -> None:
         logger.info("In-process model runner initialised on GPU %d", self._gpu_id)

     def event_loop(self) -> None:
-        """Infinite scheduling loop."""
+        """Infinite scheduling loop.
+
+        When decode batches are active, the loop would otherwise spin at
+        100% CPU doing non-blocking ZMQ polls between GPU forward passes.
+        We track whether the previous iteration ran a decode batch and, if
+        so, use a brief poll timeout (default 1 ms) in ``recv_requests``
+        so the OS can schedule other work on this core.
+        """
         logger.info(
             "SchedulerProcess event loop started (shared_queue=%s, transport=%s)",
             self._enable_shared_queue,
             self._tensor_transport_mode,
         )
+        _in_decode = False
         while True:
-            self.recv_requests()
+            self.recv_requests(brief_poll=_in_decode)
             self.process_input_requests()

             batch = self.get_next_batch_to_run()
             if batch is not None:
+                _in_decode = not batch.forward_mode.is_extend()
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
+                _in_decode = False
                 # No work available -- sleep until a new request arrives
                 # on the ZMQ socket (or timeout). Avoids busy-looping.
                 self._idle_sleeper.sleep()
@@ -505,30 +537,40 @@ def event_loop(self) -> None:
     # Step 1: receive tokenized requests (non-blocking)
     # ------------------------------------------------------------------

-    def recv_requests(self) -> None:
+    def recv_requests(self, brief_poll: bool = False) -> None:
         """Non-blocking receive of tokenized requests from TokenizerProcess.

         Supports two modes:

         1. Legacy ZMQ: Uses ``zmq.Poller`` with a short timeout
         2. Shared queue: Non-blocking get from multiprocessing.Queue

+        When *brief_poll* is ``True`` (typically during active decode), the
+        first poll uses a small timeout (``_DECODE_POLL_TIMEOUT_MS``) instead
+        of zero. This yields the CPU core to the OS scheduler between decode
+        batches while adding negligible latency.
+
         Messages are either:

         * A :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput`
-          dataclass – appended to ``_waiting_queue``.
-        * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` – handled
+          dataclass - appended to ``_waiting_queue``.
+        * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` - handled
           inline by removing the matching rid from the waiting queue.
         """
         if self._enable_shared_queue and self._shared_queue is not None:
-            self._recv_from_shared_queue()
+            self._recv_from_shared_queue(brief_poll=brief_poll)
         else:
-            self._recv_from_zmq()
+            self._recv_from_zmq(brief_poll=brief_poll)

-    def _recv_from_zmq(self) -> None:
+    def _recv_from_zmq(self, brief_poll: bool = False) -> None:
         """Receive requests via legacy ZMQ path."""
+        # On the first poll, use a brief timeout if requested (decode path)
+        # to yield the CPU. After draining the first message, switch to
+        # non-blocking polls for any remaining queued messages.
+        poll_timeout = _DECODE_POLL_TIMEOUT_MS if brief_poll else 0
         while True:
-            events = dict(self._poller.poll(timeout=0))  # non-blocking
+            events = dict(self._poller.poll(timeout=poll_timeout))
             if self._recv_from_tokenizer not in events:
                 break
+            poll_timeout = 0  # drain remaining messages without blocking
             msg = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
             # Abort sentinel: plain dict with "abort" key.
             if isinstance(msg, dict) and msg.get("abort"):
@@ -542,7 +584,7 @@ def _recv_from_zmq(self) -> None:
             else:
                 self._waiting_queue.append(msg)

-    def _recv_from_shared_queue(self) -> None:
+    def _recv_from_shared_queue(self, brief_poll: bool = False) -> None:
         """Receive requests via shared memory + shared queue fast path.

         After reading a ``(rid, shm_name, mm_inputs)`` tuple from the queue:
@@ -556,9 +598,13 @@ def _recv_from_shared_queue(self) -> None:
         3. A full ``TokenizedGenerateReqInput`` is assembled and appended to
            ``_waiting_queue``.
         """
+        # The first get uses the decode poll timeout when in decode mode to
+        # yield the CPU; subsequent gets revert to the standard 2 ms drain
+        # timeout.
+        get_timeout = _DECODE_POLL_TIMEOUT_MS / 1000.0 if brief_poll else 0.002
         while True:
             try:
-                rid, shm_name, mm_inputs = self._shared_queue.get(timeout=0.002)
+                rid, shm_name, mm_inputs = self._shared_queue.get(timeout=get_timeout)
+                get_timeout = 0.002  # drain remaining without extra delay
                 # Read metadata from shared memory (and unlink immediately)
                 metadata: TokenizedGenerateReqInput = SharedMemoryManager.read_metadata(
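
Note for reviewers: the core idea in `_recv_from_zmq` is a "brief poll, then drain" loop. Only the first poll may wait (up to `_DECODE_POLL_TIMEOUT_MS`); as soon as one message arrives, the remaining backlog is drained with non-blocking polls. Below is a minimal, self-contained sketch of that pattern using plain pyzmq; the PAIR/inproc wiring, the `drain` function, and the demo payloads are illustrative only and not part of pymllm.

```python
# Illustrative sketch (not pymllm code): the brief-poll + drain pattern,
# demonstrated with a standalone pyzmq PAIR socket over inproc.
import zmq


def drain(poller: zmq.Poller, sock: zmq.Socket,
          brief_poll: bool, timeout_ms: int = 1) -> list:
    """Receive every queued message; only the first poll may wait."""
    messages = []
    poll_timeout = timeout_ms if brief_poll else 0
    while True:
        events = dict(poller.poll(timeout=poll_timeout))
        if sock not in events:
            break  # nothing (more) to read
        poll_timeout = 0  # a message arrived: drain the rest without waiting
        messages.append(sock.recv_pyobj(zmq.NOBLOCK))
    return messages


if __name__ == "__main__":
    ctx = zmq.Context.instance()
    rx = ctx.socket(zmq.PAIR)
    rx.bind("inproc://drain-demo")  # bind before connect for inproc
    tx = ctx.socket(zmq.PAIR)
    tx.connect("inproc://drain-demo")

    poller = zmq.Poller()
    poller.register(rx, zmq.POLLIN)

    for rid in range(3):
        tx.send_pyobj({"rid": rid})

    # At most one brief (<= 1 ms) wait, then a non-blocking drain.
    print(drain(poller, rx, brief_poll=True))
```

The shared-queue path applies the same first-wait-then-drain idea: the first `Queue.get` uses the decode timeout, and once a tuple arrives, subsequent gets fall back to the pre-existing 2 ms drain timeout.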
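One caveat on configurability: `_DECODE_POLL_TIMEOUT_MS` is bound once at import time, so changing `MLLM_DECODE_POLL_TIMEOUT_MS` at runtime has no effect; tests should either set the env var before first import or exercise `_read_decode_poll_timeout_ms` directly. A hedged pytest-style sketch follows; the module path comes from this diff, but the test cases themselves are mine, not from the repo.

```python
# Hypothetical tests for the env-var validation added above; illustrative only.
import pytest

from pymllm.orchestrator.scheduler_process import _read_decode_poll_timeout_ms


def test_default_is_one_ms(monkeypatch):
    monkeypatch.delenv("MLLM_DECODE_POLL_TIMEOUT_MS", raising=False)
    assert _read_decode_poll_timeout_ms() == 1


def test_override(monkeypatch):
    monkeypatch.setenv("MLLM_DECODE_POLL_TIMEOUT_MS", "5")
    assert _read_decode_poll_timeout_ms() == 5


@pytest.mark.parametrize("bad", ["-1", "abc", "1.5"])
def test_rejects_invalid(monkeypatch, bad):
    monkeypatch.setenv("MLLM_DECODE_POLL_TIMEOUT_MS", bad)
    with pytest.raises(ValueError):
        _read_decode_poll_timeout_ms()
```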