From a8831ca24e4118666585c53190e7eab9348a53a9 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 03:12:47 +0000 Subject: [PATCH 1/2] fmq_a2e --- fastdeploy/engine/common_engine.py | 43 +++++++++++++++++--- fastdeploy/entrypoints/engine_client.py | 27 +++++++----- fastdeploy/entrypoints/openai/api_server.py | 3 -- fastdeploy/inter_communicator/__init__.py | 4 ++ fastdeploy/inter_communicator/fmq.py | 37 +++++++++++------ fastdeploy/inter_communicator/fmq_factory.py | 33 ++++++++++----- 6 files changed, 104 insertions(+), 43 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index bc3dbd78d05..9fb475cf4d6 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import copy import json import multiprocessing @@ -46,6 +47,7 @@ from fastdeploy.inter_communicator import ( EngineCacheQueue, EngineWorkerQueue, + FMQFactory, IPCSignal, ZmqIpcServer, ZmqTcpServer, @@ -1048,20 +1050,28 @@ def start_zmq_service(self, api_server_pid=None): cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id ) else: - self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) + self.fmq_a2e_consumer = None self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) self.recv_result_handle_thread = threading.Thread( target=self.send_response_server.recv_result_handle, daemon=True ) self.recv_result_handle_thread.start() time.sleep(3) - self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True) + self.insert_task_to_scheduler_thread = threading.Thread( + target=self._run_insert_zmq_task_to_scheduler, daemon=True + ) self.insert_task_to_scheduler_thread.start() self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) self.receive_output_thread.start() - def _insert_zmq_task_to_scheduler(self): + def _run_insert_zmq_task_to_scheduler(self): + try: + asyncio.run(self._insert_zmq_task_to_scheduler()) + except Exception as e: + self.llm_logger.error(f"Async loop crashed: {e}") + + async def _insert_zmq_task_to_scheduler(self): tracing.trace_set_thread_info("Insert Task to Scheduler") added_requests: Dict[str, int] = dict() if envs.FD_ENABLE_INTERNAL_ADAPTER: @@ -1071,10 +1081,22 @@ def _insert_zmq_task_to_scheduler(self): while self.running: try: block = True if len(added_requests) == 0 else False - if not self.cfg.model_config.enable_mm: - err, data = self.recv_request_server.receive_json_once(block) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + if not self.cfg.model_config.enable_mm: + err, data = self.recv_request_server.receive_json_once(block) + else: + err, data = self.recv_request_server.receive_pyobj_once(block) else: - err, data = self.recv_request_server.receive_pyobj_once(block) + err = None + if self.fmq_a2e_consumer is None: + self.fmq_a2e_consumer = FMQFactory.q_a2e_consumer() + try: + msg = await self.fmq_a2e_consumer.get() + if msg is None: + continue + data = msg.payload + except Exception as e: + err = e if err is not None: # The message "Context was terminated" is normal when closing a ZMQ context if "Context was terminated" in str(err): @@ -1516,6 +1538,15 @@ def _exit_sub_services(self): self.recv_request_server.close() if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: self.recv_control_cmd_server.close() + if hasattr(self, "fmq_a2e_consumer") and 
self.fmq_a2e_consumer is not None: + try: + if hasattr(self.fmq_a2e_consumer, "socket") and self.fmq_consumer.socket is not None: + self.fmq_a2e_consumer.socket.close() + llm_logger.info("FMQ consumer socket closed successfully.") + except Exception as e: + llm_logger.error(f"Error closing fmq_consumer: {e}") + finally: + self.fmq_a2e_consumer = None # 从 async_llm 移到 common_engine def _worker_processes_ready(self): diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 9babe8fec74..ec33da56324 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -33,6 +33,7 @@ from fastdeploy.eplb.utils import RedundantExpertWorkload from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import ( + FMQFactory, IPCSignal, KVCacheStatus, ModelWeightsStatus, @@ -49,7 +50,6 @@ ParameterError, StatefulSemaphore, api_server_logger, - to_tensor, ) @@ -82,6 +82,7 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed" self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 + self.fmq_a2e_producer = None if self.tensor_parallel_size <= self.max_chips_per_node: self.is_master = True @@ -348,7 +349,7 @@ async def add_requests(self, task): request_id_idx = task.get("request_id") parts = request_id_idx.rsplit("_", 1) if len(parts) == 1: - self._send_task(task) + await self._send_task(task) else: request_id = parts[0] index = int(parts[1]) @@ -357,21 +358,25 @@ async def add_requests(self, task): for i in range(index * n, (index + 1) * n): child_task = copy(task) child_task["request_id"] = f"{request_id}_{i}" - self._send_task(child_task) + await self._send_task(child_task) tracing.trace_slice_end( tracing.TraceSpanName.PREPROCESSING, task.get("request_id").split("_")[0], thread_finish_flag=True ) except Exception as e: - api_server_logger.error(f"zmq_client send task error: {e}, {str(traceback.format_exc())}") + api_server_logger.error(f"fmq send task error: {e}, {str(traceback.format_exc())}") raise EngineError(str(e), error_code=400) - def _send_task(self, task): - if not self.enable_mm: - self.zmq_client.send_json(task) - else: - if envs.FD_ENABLE_E2W_TENSOR_CONVERT: - to_tensor([task]) - self.zmq_client.send_pyobj(task) + def _get_producer(self): + if self.fmq_a2e_producer is None: + self.fmq_a2e_producer = FMQFactory.q_a2e_producer() + return self.fmq_a2e_producer + + async def _send_task(self, task): + try: + producer = self._get_producer() + await producer.put(task) + except Exception as e: + api_server_logger.error(f"Failed to send task via FMQ: {e}") def valid_parameters(self, data): """ diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 2744c9388c0..5e10ec7cf30 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -24,7 +24,6 @@ from contextlib import asynccontextmanager import uvicorn -import zmq from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -219,7 +218,6 @@ async def lifespan(app: FastAPI): reward_handler = OpenAIServingReward( engine_client, app.state.model_handler, fd_config, pid, args.ips, args.max_waiting_time, chat_template ) - 
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) engine_client.pid = pid app.state.engine_client = engine_client app.state.chat_handler = chat_handler @@ -233,7 +231,6 @@ async def lifespan(app: FastAPI): # close zmq try: await engine_client.connection_manager.close() - engine_client.zmq_client.close() from prometheus_client import multiprocess multiprocess.mark_process_dead(os.getpid()) diff --git a/fastdeploy/inter_communicator/__init__.py b/fastdeploy/inter_communicator/__init__.py index 6331e06b955..9f75e582852 100644 --- a/fastdeploy/inter_communicator/__init__.py +++ b/fastdeploy/inter_communicator/__init__.py @@ -16,6 +16,8 @@ from .engine_cache_queue import EngineCacheQueue from .engine_worker_queue import EngineWorkerQueue +from .fmq import FMQ +from .fmq_factory import FMQFactory from .ipc_signal import IPCSignal, shared_memory_exists from .ipc_signal_const import ( ExistTaskStatus, @@ -40,4 +42,6 @@ "ModelWeightsStatus", "KVCacheStatus", "RearrangeExpertStatus", + "FMQ", + "FMQFactory", ] diff --git a/fastdeploy/inter_communicator/fmq.py b/fastdeploy/inter_communicator/fmq.py index f2c98196c99..bee59397d80 100644 --- a/fastdeploy/inter_communicator/fmq.py +++ b/fastdeploy/inter_communicator/fmq.py @@ -28,6 +28,8 @@ import zmq.asyncio from fastdeploy import envs +from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.metrics.stats import ZMQMetricsStats from fastdeploy.utils import fmq_logger # ========================== @@ -232,19 +234,28 @@ async def put(self, data: Any, shm_threshold: int = 1024 * 1024): if self.role != Role.PRODUCER: raise PermissionError("Only producers can send messages.") - desc = None - payload = data - - if isinstance(data, bytes) and len(data) >= shm_threshold: - desc = Descriptor.create(data) - payload = None - - msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc) - raw = msg.serialize() - - async with self.lock: - await self.socket.send(raw, copy=self.copy) - self._msg_id += 1 + _zmq_metrics_stats = ZMQMetricsStats() + try: + desc = None + payload = data + + if isinstance(data, bytes) and len(data) >= shm_threshold: + desc = Descriptor.create(data) + payload = None + + msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc, timestamp=time.perf_counter()) + raw = msg.serialize() + _zmq_metrics_stats.msg_bytes_send_total += len(raw) + + async with self.lock: + await self.socket.send(raw, copy=self.copy) + self._msg_id += 1 + except Exception as e: + _zmq_metrics_stats.msg_send_failed_total += 1 + fmq_logger.error(f"Failed to send message: {e}") + finally: + _zmq_metrics_stats.msg_send_total += 1 + main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.endpoint.address) async def get(self, timeout: int = None) -> Optional[Message]: # Receive data from queue diff --git a/fastdeploy/inter_communicator/fmq_factory.py b/fastdeploy/inter_communicator/fmq_factory.py index d1c8e4dd244..8ed95bbd52b 100644 --- a/fastdeploy/inter_communicator/fmq_factory.py +++ b/fastdeploy/inter_communicator/fmq_factory.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" +import os + from fastdeploy.inter_communicator.fmq import FMQ @@ -29,55 +31,66 @@ class FMQFactory: Worker: q_e2w consumer / q_w2e producer """ - _fmq = FMQ() + _fmq = None + _pid = None + + @classmethod + def _get_fmq(cls): + current_pid = os.getpid() + if cls._pid != current_pid: + FMQ._instance = None + FMQ._context = None + cls._fmq = FMQ() + cls._pid = current_pid + return cls._fmq # ------------------------------ # API → Engine # ------------------------------ @classmethod def q_a2e_producer(cls): - return cls._fmq.queue("q_a2e", role="producer") + return cls._get_fmq().queue("q_a2e", role="producer") @classmethod def q_a2e_consumer(cls): - return cls._fmq.queue("q_a2e", role="consumer") + return cls._get_fmq().queue("q_a2e", role="consumer") # ------------------------------ # Engine → Worker # ------------------------------ @classmethod def q_e2w_producer(cls): - return cls._fmq.queue("q_e2w", role="producer") + return cls._get_fmq().queue("q_e2w", role="producer") @classmethod def q_e2w_consumer(cls): - return cls._fmq.queue("q_e2w", role="consumer") + return cls._get_fmq().queue("q_e2w", role="consumer") # ------------------------------ # Worker → Engine # ------------------------------ @classmethod def q_w2e_producer(cls): - return cls._fmq.queue("q_w2e", role="producer") + return cls._get_fmq().queue("q_w2e", role="producer") @classmethod def q_w2e_consumer(cls): - return cls._fmq.queue("q_w2e", role="consumer") + return cls._get_fmq().queue("q_w2e", role="consumer") # ------------------------------ # Engine → API # ------------------------------ @classmethod def q_e2a_producer(cls): - return cls._fmq.queue("q_e2a", role="producer") + return cls._get_fmq().queue("q_e2a", role="producer") @classmethod def q_e2a_consumer(cls): - return cls._fmq.queue("q_e2a", role="consumer") + return cls._get_fmq().queue("q_e2a", role="consumer") # ------------------------------ # Destroy context # ------------------------------ @classmethod async def destroy(cls): - await cls._fmq.destroy() + await cls._get_fmq().destroy() From bbaf5de1de149f7796d1fd58114a124d25b4176d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 24 Dec 2025 08:58:36 +0000 Subject: [PATCH 2/2] add unittest --- fastdeploy/engine/async_llm.py | 27 +- fastdeploy/engine/common_engine.py | 2 +- fastdeploy/entrypoints/openai/api_server.py | 12 + tests/engine/test_exit_sub_services_fmq.py | 164 ++++++ .../test_insert_zmq_task_to_scheduler.py | 471 ++++++++++++++++++ tests/engine/test_start_zmq_service_thread.py | 267 ++++++++++ .../test_engine_client_send_task.py | 142 ++++++ 7 files changed, 1079 insertions(+), 6 deletions(-) create mode 100644 tests/engine/test_exit_sub_services_fmq.py create mode 100644 tests/engine/test_insert_zmq_task_to_scheduler.py create mode 100644 tests/engine/test_start_zmq_service_thread.py create mode 100644 tests/entrypoints/test_engine_client_send_task.py diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index edee21af066..dfc812ef179 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -35,7 +35,7 @@ from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.inter_communicator import FMQFactory, IPCSignal from fastdeploy.inter_communicator.zmq_client import ZmqIpcClient from fastdeploy.metrics.metrics import main_process_metrics from 
fastdeploy.utils import EngineError, llm_logger @@ -299,6 +299,7 @@ def __init__(self, cfg, pid): # Create high-performance async connection manager self.connection_manager = None self.request_client = None + self.fmq_a2e_producer = None # Output processor uses data_processor for post-processing engine outputs self.output_processor = AsyncOutputProcessor(self.data_processor) @@ -307,6 +308,11 @@ def __init__(self, cfg, pid): main_process_metrics.set_cache_config_info(obj=self.cfg.cache_config) + def _get_producer(self): + if self.fmq_a2e_producer is None: + self.fmq_a2e_producer = FMQFactory.q_a2e_producer() + return self.fmq_a2e_producer + async def init_connections(self): """Initialize high-performance ZMQ connections""" try: @@ -439,10 +445,11 @@ async def add_request( f"preprocess time cost {preprocess_cost_time}" ) - if not self.cfg.model_config.enable_mm: - self.request_client.send_json(request) - else: - self.request_client.send_pyobj(request) + try: + producer = self._get_producer() + await producer.put(request) + except Exception as e: + llm_logger.error(f"Failed to send task via FMQ: {e}") except EngineError: raise @@ -603,6 +610,16 @@ async def shutdown(self): self.request_client.close() except Exception as e: llm_logger.warning(f"Error closing request client: {e}") + # Close FMQ producer + if hasattr(self, "fmq_a2e_producer") and self.fmq_a2e_producer is not None: + try: + if hasattr(self.fmq_a2e_producer, "socket") and self.fmq_a2e_producer.socket is not None: + self.fmq_a2e_producer.socket.close() + llm_logger.info("FMQ producer socket closed successfully.") + except Exception as e: + llm_logger.error(f"Error closing fmq_producer: {e}") + finally: + self.fmq_a2e_producer = None # Shutdown engine service process try: diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 7f65dbe13f9..9f1f6dc7a4d 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1517,7 +1517,7 @@ def _exit_sub_services(self): self.recv_control_cmd_server.close() if hasattr(self, "fmq_a2e_consumer") and self.fmq_a2e_consumer is not None: try: - if hasattr(self.fmq_a2e_consumer, "socket") and self.fmq_consumer.socket is not None: + if hasattr(self.fmq_a2e_consumer, "socket") and self.fmq_a2e_consumer.socket is not None: self.fmq_a2e_consumer.socket.close() llm_logger.info("FMQ consumer socket closed successfully.") except Exception as e: diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index b4ed396079c..fe73241b3f1 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -232,6 +232,18 @@ async def lifespan(app: FastAPI): # close zmq try: await engine_client.connection_manager.close() + if hasattr(engine_client, "fmq_a2e_producer") and engine_client.fmq_a2e_producer is not None: + try: + if ( + hasattr(engine_client.fmq_a2e_producer, "socket") + and engine_client.fmq_a2e_producer.socket is not None + ): + engine_client.fmq_a2e_producer.socket.close() + api_server_logger.info("FMQ producer socket closed successfully.") + except Exception as e: + api_server_logger.error(f"Error closing fmq_producer: {e}") + finally: + engine_client.fmq_a2e_producer = None from prometheus_client import multiprocess multiprocess.mark_process_dead(os.getpid()) diff --git a/tests/engine/test_exit_sub_services_fmq.py b/tests/engine/test_exit_sub_services_fmq.py new file mode 100644 index 00000000000..698ab003dbf --- /dev/null +++ 
b/tests/engine/test_exit_sub_services_fmq.py @@ -0,0 +1,164 @@ +from unittest.mock import Mock, patch + +import pytest + +from fastdeploy.engine.common_engine import EngineService + + +class TestExitSubServicesFMQ: + """测试 _exit_sub_services 方法中 fmq_a2e_consumer 清理逻辑""" + + @pytest.fixture + def mock_engine_service(self): + """创建模拟的 EngineService 实例""" + engine = Mock(spec=EngineService) + engine.llm_logger = Mock() + engine.running = True + engine.use_async_llm = True + # 添加 _exit_sub_services 方法中需要的 signal 属性 + engine.exist_task_signal = Mock() + engine.exist_swapped_task_signal = Mock() + engine.worker_healthy_live_signal = Mock() + engine.cache_ready_signal = Mock() + engine.swap_space_ready_signal = Mock() + engine.exist_prefill_task_signal = Mock() + engine.model_weights_status_signal = Mock() + engine.prefix_tree_status_signal = Mock() + engine.kv_cache_status_signal = Mock() + engine.worker_ready_signal = Mock() + engine.loaded_model_signal = Mock() + return engine + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_with_socket_close_success(self, mock_llm_logger, mock_engine_service): + """测试 fmq_a2e_consumer 有 socket 且关闭成功的情况""" + # 模拟 fmq_a2e_consumer 对象,有 socket 属性 + mock_socket = Mock() + mock_fmq_consumer = Mock() + mock_fmq_consumer.socket = mock_socket + mock_engine_service.fmq_a2e_consumer = mock_fmq_consumer + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证 socket.close() 被调用 + mock_socket.close.assert_called_once() + mock_llm_logger.info.assert_any_call("FMQ consumer socket closed successfully.") + # 验证 finally 块中 fmq_a2e_consumer 被设置为 None + assert mock_engine_service.fmq_a2e_consumer is None + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_with_socket_close_exception(self, mock_llm_logger, mock_engine_service): + """测试 fmq_a2e_consumer 有 socket 但关闭时抛出异常的情况""" + # 模拟 fmq_a2e_consumer 对象,有 socket 属性 + mock_socket = Mock() + mock_socket.close.side_effect = Exception("Socket close failed") + mock_fmq_consumer = Mock() + mock_fmq_consumer.socket = mock_socket + mock_engine_service.fmq_a2e_consumer = mock_fmq_consumer + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证 socket.close() 被调用 + mock_socket.close.assert_called_once() + # 验证异常被捕获并记录 error 日志 + mock_llm_logger.error.assert_called_once() + error_call_args = mock_llm_logger.error.call_args[0][0] + assert "Error closing fmq_consumer: Socket close failed" in error_call_args + # 验证 finally 块中 fmq_a2e_consumer 仍然被设置为 None + assert mock_engine_service.fmq_a2e_consumer is None + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_no_socket(self, mock_llm_logger, mock_engine_service): + """测试 fmq_a2e_consumer 没有 socket 属性的情况""" + # 模拟 fmq_a2e_consumer 对象,没有 socket 属性 + mock_fmq_consumer = Mock() + # 移除 socket 属性 + del mock_fmq_consumer.socket + mock_engine_service.fmq_a2e_consumer = mock_fmq_consumer + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证没有调用 socket.close() + # 验证没有记录 "FMQ consumer socket closed successfully." 
日志 + info_calls = [call[0][0] for call in mock_llm_logger.info.call_args_list] + assert "FMQ consumer socket closed successfully." not in info_calls + # 验证 finally 块中 fmq_a2e_consumer 被设置为 None + assert mock_engine_service.fmq_a2e_consumer is None + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_socket_none(self, mock_llm_logger, mock_engine_service): + """测试 fmq_a2e_consumer 有 socket 属性但 socket 为 None 的情况""" + # 模拟 fmq_a2e_consumer 对象,socket 为 None + mock_fmq_consumer = Mock() + mock_fmq_consumer.socket = None + mock_engine_service.fmq_a2e_consumer = mock_fmq_consumer + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证没有调用 socket.close() + # 验证没有记录 "FMQ consumer socket closed successfully." 日志 + info_calls = [call[0][0] for call in mock_llm_logger.info.call_args_list] + assert "FMQ consumer socket closed successfully." not in info_calls + # 验证 finally 块中 fmq_a2e_consumer 被设置为 None + assert mock_engine_service.fmq_a2e_consumer is None + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_none(self, mock_llm_logger, mock_engine_service): + """测试 fmq_a2e_consumer 为 None 的情况""" + # 设置 fmq_a2e_consumer 为 None + mock_engine_service.fmq_a2e_consumer = None + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证 fmq_a2e_consumer 保持为 None + assert mock_engine_service.fmq_a2e_consumer is None + # 验证没有调用任何 socket 相关方法 + info_calls = [call[0][0] for call in mock_llm_logger.info.call_args_list] + assert "FMQ consumer socket closed successfully." not in info_calls + error_calls = [call[0][0] for call in mock_llm_logger.error.call_args_list] + assert not any("Error closing fmq_consumer" in call for call in error_calls) + + @patch("fastdeploy.engine.common_engine.llm_logger") + def test_fmq_a2e_consumer_hasattr_false(self, mock_llm_logger, mock_engine_service): + """测试对象没有 fmq_a2e_consumer 属性的情况""" + # 确保对象没有 fmq_a2e_consumer 属性 + if hasattr(mock_engine_service, "fmq_a2e_consumer"): + delattr(mock_engine_service, "fmq_a2e_consumer") + + # 创建真实的函数并调用 + real_function = EngineService._exit_sub_services + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数 + bound_method() + + # 验证没有调用任何 socket 相关方法 + info_calls = [call[0][0] for call in mock_llm_logger.info.call_args_list] + assert "FMQ consumer socket closed successfully." 
not in info_calls + error_calls = [call[0][0] for call in mock_llm_logger.error.call_args_list] + assert not any("Error closing fmq_consumer" in call for call in error_calls) diff --git a/tests/engine/test_insert_zmq_task_to_scheduler.py b/tests/engine/test_insert_zmq_task_to_scheduler.py new file mode 100644 index 00000000000..59544eb43e1 --- /dev/null +++ b/tests/engine/test_insert_zmq_task_to_scheduler.py @@ -0,0 +1,471 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from fastdeploy.engine.common_engine import EngineService +from fastdeploy.engine.request import Request, RequestMetrics + + +class TestInsertZmqTaskToScheduler: + """测试 _insert_zmq_task_to_scheduler 函数的单元测试""" + + @pytest.fixture + def mock_engine_service(self): + """创建模拟的 EngineService 实例""" + engine = Mock(spec=EngineService) + engine.running = True + engine.cfg = Mock() + engine.cfg.scheduler_config = Mock() + engine.cfg.scheduler_config.splitwise_role = "prefill" + engine.cfg.model_config = Mock() + engine.cfg.model_config.enable_mm = False + engine.llm_logger = Mock() + engine.scheduler = Mock() + engine.guided_decoding_checker = None + engine.fmq_a2e_consumer = None + + # 模拟 added_requests 字典 + engine.added_requests = {} + + return engine + + @pytest.fixture + def sample_request_data(self): + """示例请求数据""" + return { + "request_id": "test_req_123", + "prompt": "Hello, world!", + "prompt_token_ids": [1, 2, 3, 4], + "prompt_token_ids_len": 4, + "messages": None, + "history": None, + "tools": None, + "system": None, + "eos_token_ids": [2], + "sampling_params": {"temperature": 0.7, "top_p": 0.9, "max_tokens": 100}, + "user": "test_user", + } + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_with_internal_adapter_json( + self, mock_engine_service, sample_request_data + ): + """测试 FD_ENABLE_INTERNAL_ADAPTER=True 且 enable_mm=False 的情况(JSON模式)""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + with patch("fastdeploy.engine.common_engine.trace_print") as mock_trace: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_recv_server.receive_json_once.return_value = (None, sample_request_data) + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + mock_trace = mock_trace if mock_trace is not None else mock_trace + + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = [("test_req_123", None)] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + mock_metrics.num_requests_waiting = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 由于这是一个无限循环的函数,我们需要模拟只运行一次 + # 使用 side_effect 来控制循环退出 + call_count = 0 + + def mock_receive_json_once(block): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (None, sample_request_data) + else: + # 模拟 Context was terminated 错误来退出循环 + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_json_once.side_effect = mock_receive_json_once + + # 执行函数 + await bound_method() + + # 验证调用 + mock_recv_server.receive_json_once.assert_called() + 
mock_from_dict.assert_called_once_with(sample_request_data) + mock_engine_service.scheduler.put_requests.assert_called_once() + mock_metrics.requests_number.inc.assert_called_once() + mock_metrics.num_requests_waiting.inc.assert_called_once() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_with_internal_adapter_pyobj( + self, mock_engine_service, sample_request_data + ): + """测试 FD_ENABLE_INTERNAL_ADAPTER=True 且 enable_mm=True 的情况(PyObj模式)""" + mock_engine_service.cfg.model_config.enable_mm = True + + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + with patch("fastdeploy.engine.common_engine.trace_print") as mock_trace: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_recv_server.receive_pyobj_once.return_value = (None, sample_request_data) + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + mock_trace = mock_trace if mock_trace is not None else mock_trace + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = [("test_req_123", None)] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + mock_metrics.num_requests_waiting = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + call_count = 0 + + def mock_receive_pyobj_once(block): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (None, sample_request_data) + else: + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_pyobj_once.side_effect = mock_receive_pyobj_once + + # 执行函数 + await bound_method() + + # 验证调用 + mock_recv_server.receive_pyobj_once.assert_called() + mock_from_dict.assert_called_once_with(sample_request_data) + mock_engine_service.scheduler.put_requests.assert_called_once() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_without_internal_adapter( + self, mock_engine_service, sample_request_data + ): + """测试 FD_ENABLE_INTERNAL_ADAPTER=False 的情况(FMQ模式)""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", False): + with patch("fastdeploy.engine.common_engine.FMQFactory") as mock_fmq_factory: + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + # 模拟 FMQ consumer + mock_consumer = AsyncMock() + mock_msg = Mock() + mock_msg.payload = sample_request_data + mock_fmq_factory.q_a2e_consumer.return_value = mock_consumer + mock_engine_service.fmq_a2e_consumer = mock_consumer + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = [("test_req_123", None)] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + mock_metrics.num_requests_waiting = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, 
EngineService) + + call_count = 0 + + async def mock_get(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_msg + else: + # 抛出异常来退出循环,而不是返回 None + raise Exception("FMQ connection terminated") + + mock_consumer.get.side_effect = mock_get + + # 执行函数 + await bound_method() + + # 验证调用 + mock_consumer.get.assert_called() + mock_from_dict.assert_called_once_with(sample_request_data) + mock_engine_service.scheduler.put_requests.assert_called_once() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_decode_role_early_return(self, mock_engine_service): + """测试 splitwise_role='decode' 时的早期返回""" + mock_engine_service.cfg.scheduler_config.splitwise_role = "decode" + + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数,应该立即返回 + await bound_method() + + # 验证没有进行任何网络调用 + assert ( + not hasattr(mock_engine_service, "recv_request_server") + or mock_engine_service.recv_request_server is None + ) + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_request_error(self, mock_engine_service, sample_request_data): + """测试请求解析错误的情况""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_recv_server.receive_json_once.return_value = (None, sample_request_data) + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 抛出异常 + mock_from_dict.side_effect = Exception("Invalid request data") + + # 模拟 _send_error_response + mock_engine_service._send_error_response = Mock() + + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = [] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + call_count = 0 + + def mock_receive_json_once(block): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (None, sample_request_data) + else: + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_json_once.side_effect = mock_receive_json_once + + # 执行函数 + await bound_method() + + # 验证错误处理 + mock_from_dict.assert_called_once_with(sample_request_data) + mock_engine_service._send_error_response.assert_called_once() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_guided_decoding_error(self, mock_engine_service, sample_request_data): + """测试 guided_decoding_checker 错误的情况""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_recv_server.receive_json_once.return_value = (None, sample_request_data) + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + + # 模拟 
guided_decoding_checker + mock_checker = Mock() + mock_checker.schema_format.return_value = (mock_request, "Schema validation error") + mock_engine_service.guided_decoding_checker = mock_checker + + # 模拟 _send_error_response + mock_engine_service._send_error_response = Mock() + + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = [] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + call_count = 0 + + def mock_receive_json_once(block): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (None, sample_request_data) + else: + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_json_once.side_effect = mock_receive_json_once + + # 执行函数 + await bound_method() + + # 验证 guided_decoding_checker 调用 + mock_checker.schema_format.assert_called_once_with(mock_request) + mock_engine_service._send_error_response.assert_called_once() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_fmq_get_exception(self, mock_engine_service): + """测试 FMQ consumer.get() 异常的情况""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", False): + with patch("fastdeploy.engine.common_engine.FMQFactory") as mock_fmq_factory: + # 模拟 FMQ consumer + mock_consumer = AsyncMock() + mock_consumer.get.side_effect = Exception("FMQ connection error") + mock_fmq_factory.q_a2e_consumer.return_value = mock_consumer + mock_engine_service.fmq_a2e_consumer = mock_consumer + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数,应该因为异常而退出循环 + await bound_method() + + # 验证 consumer.get 被调用 + mock_consumer.get.assert_called() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_zmq_context_terminated(self, mock_engine_service): + """测试 ZMQ context 终止的正常关闭情况""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_recv_server.receive_json_once.return_value = (Exception("Context was terminated"), None) + mock_engine_service.recv_request_server = mock_recv_server + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + # 执行函数,应该因为 Context was terminated 而正常退出 + await bound_method() + + # 验证 receive_json_once 被调用 + mock_recv_server.receive_json_once.assert_called() + # 验证记录了 info 日志 + mock_engine_service.llm_logger.info.assert_called() + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_multiple_requests(self, mock_engine_service, sample_request_data): + """测试处理多个请求的情况""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + + # 模拟 scheduler.put_requests 返回 + mock_engine_service.scheduler.put_requests.return_value = 
[("test_req_123", None)] + + # 模拟 metrics + mock_metrics.requests_number = Mock() + mock_metrics.num_requests_waiting = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + call_count = 0 + + def mock_receive_json_once(block): + nonlocal call_count + call_count += 1 + if call_count <= 2: # 处理两个请求 + return (None, sample_request_data) + else: + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_json_once.side_effect = mock_receive_json_once + + # 执行函数 + await bound_method() + + # 验证多次调用 + assert mock_recv_server.receive_json_once.call_count == 3 + assert mock_from_dict.call_count == 2 + assert mock_engine_service.scheduler.put_requests.call_count == 2 + assert mock_metrics.requests_number.inc.call_count == 2 + assert mock_metrics.num_requests_waiting.inc.call_count == 2 + + @pytest.mark.asyncio + async def test_insert_zmq_task_to_scheduler_block_parameter(self, mock_engine_service, sample_request_data): + """测试 block 参数的逻辑""" + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True): + with patch("fastdeploy.engine.common_engine.Request.from_dict") as mock_from_dict: + with patch("fastdeploy.engine.common_engine.main_process_metrics") as mock_metrics: + # 模拟 recv_request_server + mock_recv_server = Mock() + mock_engine_service.recv_request_server = mock_recv_server + + # 模拟 Request.from_dict 返回 + mock_request = Mock(spec=Request) + mock_request.request_id = "test_req_123" + mock_request.metrics = RequestMetrics() + mock_from_dict.return_value = mock_request + + # 模拟 scheduler.put_requests 返回成功,请求会从 added_requests 中移除 + mock_engine_service.scheduler.put_requests.return_value = [("test_req_123", None)] + mock_engine_service._send_error_response = Mock() + + # 模拟 metrics + mock_metrics.requests_number = Mock() + mock_metrics.num_requests_waiting = Mock() + + # 创建真实的函数并调用 + real_function = EngineService._insert_zmq_task_to_scheduler + bound_method = real_function.__get__(mock_engine_service, EngineService) + + calls = [] + + def mock_receive_json_once(block): + calls.append(block) + if len(calls) == 1: + # 第一次调用,added_requests 为空,block 应该是 True + return (None, sample_request_data) + elif len(calls) == 2: + # 第二次调用,added_requests 为空(因为第一个请求成功处理并移除),block 应该是 True + return (None, sample_request_data) + else: + return (Exception("Context was terminated"), None) + + mock_recv_server.receive_json_once.side_effect = mock_receive_json_once + + # 执行函数 + await bound_method() + + # 验证 block 参数的逻辑 + assert calls[0] is True # 第一次调用,added_requests 为空 + assert calls[1] is True # 第二次调用,added_requests 也为空(第一个请求已成功处理) diff --git a/tests/engine/test_start_zmq_service_thread.py b/tests/engine/test_start_zmq_service_thread.py new file mode 100644 index 00000000000..f9cbc1e7b11 --- /dev/null +++ b/tests/engine/test_start_zmq_service_thread.py @@ -0,0 +1,267 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import threading +import unittest +from unittest.mock import MagicMock, Mock, patch + +from fastdeploy.engine.common_engine import EngineService + + +class TestStartZmqServiceThread(unittest.TestCase): + """Test case for start_zmq_service method thread creation""" + + def setUp(self): + """Set up for each test method""" + # Create a mock config to avoid model loading issues + self.cfg = MagicMock() + self.cfg.parallel_config.local_engine_worker_queue_port = 6808 + self.cfg.parallel_config.local_data_parallel_id = 0 + self.cfg.scheduler_config.splitwise_role = "mixed" + self.cfg.model_config.enable_mm = False + + # Mock the scheduler and other dependencies + self.cfg.scheduler_config.scheduler.return_value = MagicMock() + self.cfg.max_num_partial_prefills = 1 + self.cfg.scheduler_config.max_num_batched_tokens = 1024 + self.cfg.cache_config.block_size = 16 + self.cfg.cache_config.enable_prefix_caching = False + self.cfg.structured_outputs_config.guided_decoding_backend = "off" + self.cfg.eplb_config.enable_eplb = False + self.cfg.router_config.router = None + self.cfg.max_prefill_batch = 1 + self.cfg.limit_mm_per_prompt = 1 + self.cfg.mm_processor_kwargs = {} + self.cfg.tool_parser = None + self.cfg.cache_config.num_gpu_blocks_override = 4 + self.cfg.worker_num_per_node = 1 + self.cfg.parallel_config.tensor_parallel_size = 1 + self.cfg.parallel_config.data_parallel_size = 1 + self.cfg.parallel_config.device_ids = "0" + self.cfg.parallel_config.engine_worker_queue_port = [6808] + self.cfg.cache_config.local_cache_queue_port = 6809 + self.cfg.master_ip = "127.0.0.1" + self.cfg.host_ip = "127.0.0.1" + self.cfg.enable_decode_cache_task = False + self.cfg.splitwise_version = "v1" + self.cfg.register_info = {} + self.cfg.parallel_config.enable_expert_parallel = False + self.cfg.parallel_config.local_engine_worker_queue_port = 6808 + self.cfg.parallel_config.engine_worker_queue_port = 6808 + self.cfg.cache_config.enc_dec_block_num = 0 + self.cfg.cache_config.max_block_num_per_seq = 100 + self.cfg.model_config.max_model_len = 1024 + self.cfg.model_config.enable_mm = False + self.cfg.structured_outputs_config.disable_any_whitespace = False + self.cfg.structured_outputs_config.reasoning_parser = None + self.cfg.scheduler_config.splitwise_role = "mixed" + self.cfg.cache_config.enable_chunked_prefill = True + self.cfg.scheduler_config.max_num_seqs = 1 + self.cfg.cache_config.num_cpu_blocks = 0 + self.cfg.cache_config.total_block_num = 100 + self.cfg.cache_config.prefill_kvcache_block_num = 100 + self.cfg.speculative_config = MagicMock() + self.cfg.cache_config.max_block_num_per_seq = 100 + self.cfg.cache_config.enc_dec_block_num = 0 + self.cfg.cache_config.block_size = 16 + self.cfg.cache_config.enable_prefix_caching = False + self.cfg.cache_config.num_gpu_blocks_override = 4 + self.cfg.cache_config.kv_cache_ratio = 0.75 + + def test_start_zmq_service_thread_creation(self): + """Test that insert_task_to_scheduler_thread is created correctly in start_zmq_service""" + + # Mock the dependencies to avoid actual ZMQ setup and queue connections + with ( + patch("fastdeploy.engine.common_engine.ZmqTcpServer") as mock_zmq_tcp, + patch("fastdeploy.engine.common_engine.ZmqIpcServer") as mock_zmq_ipc, + patch("fastdeploy.engine.common_engine.InternalAdapter") as mock_internal_adapter, + patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_INTERNAL_ADAPTER", True), + 
patch("fastdeploy.engine.common_engine.time.sleep"), + patch("fastdeploy.engine.common_engine.EngineWorkerQueue") as mock_worker_queue, + patch("fastdeploy.engine.common_engine.EngineCacheQueue") as mock_cache_queue, + ): + mock_zmq_ipc = mock_zmq_ipc if mock_zmq_ipc is not None else mock_zmq_ipc + mock_internal_adapter = ( + mock_internal_adapter if mock_internal_adapter is not None else mock_internal_adapter + ) + # Create mock queue instances + mock_queue_instance = MagicMock() + mock_queue_instance.get_server_port.return_value = 6808 + mock_worker_queue.return_value = mock_queue_instance + mock_cache_queue.return_value = mock_queue_instance + + # Create engine service without starting full services + engine = EngineService(self.cfg, start_queue=False, use_async_llm=False) + engine.running = True # Add running attribute to prevent thread errors + + # Mock the send_response_server.recv_result_handle method + mock_recv_result_handle = Mock() + + # Create mock servers + mock_recv_server = Mock() + mock_send_server = Mock() + mock_send_server.recv_result_handle = mock_recv_result_handle + + mock_zmq_tcp.side_effect = [mock_recv_server, mock_send_server] + + # Call start_zmq_service with a test PID + api_server_pid = "test_pid_12345" + engine.start_zmq_service(api_server_pid) + + # Verify that insert_task_to_scheduler_thread was created + self.assertTrue(hasattr(engine, "insert_task_to_scheduler_thread")) + self.assertIsInstance(engine.insert_task_to_scheduler_thread, threading.Thread) + + # Verify thread is configured correctly - use _target instead of target for Thread objects + self.assertEqual(engine.insert_task_to_scheduler_thread._target, engine._run_insert_zmq_task_to_scheduler) + self.assertTrue(engine.insert_task_to_scheduler_thread.daemon) + + # Verify thread was started + self.assertTrue( + engine.insert_task_to_scheduler_thread.is_alive() + or engine.insert_task_to_scheduler_thread.ident is not None + ) + + # Clean up + engine.running = False # Stop the thread loop + if engine.insert_task_to_scheduler_thread.is_alive(): + engine.insert_task_to_scheduler_thread.join(timeout=1) + + def test_run_insert_zmq_task_to_scheduler_success(self): + """Test _run_insert_zmq_task_to_scheduler successful execution""" + + with ( + patch("fastdeploy.engine.common_engine.EngineWorkerQueue") as mock_worker_queue, + patch("fastdeploy.engine.common_engine.EngineCacheQueue") as mock_cache_queue, + patch("fastdeploy.engine.common_engine.asyncio.run") as mock_asyncio_run, + ): + + # Create mock queue instances + mock_queue_instance = MagicMock() + mock_queue_instance.get_server_port.return_value = 6808 + mock_worker_queue.return_value = mock_queue_instance + mock_cache_queue.return_value = mock_queue_instance + + engine = EngineService(self.cfg, start_queue=False, use_async_llm=False) + engine.running = True + + # Mock the async method to avoid actual asyncio loop + mock_asyncio_run.return_value = None + + # Call the method + engine._run_insert_zmq_task_to_scheduler() + + # Verify asyncio.run was called once + mock_asyncio_run.assert_called_once() + # Verify it was called with a coroutine object (the async method) + call_args = mock_asyncio_run.call_args[0] + self.assertEqual(len(call_args), 1) + # Check that the argument is a coroutine from the correct method + import inspect + + self.assertTrue(inspect.iscoroutine(call_args[0])) + # Check the coroutine's function name matches our target method + self.assertEqual(call_args[0].cr_code.co_name, "_insert_zmq_task_to_scheduler") + + def 
test_run_insert_zmq_task_to_scheduler_exception_handling(self): + """Test _run_insert_zmq_task_to_scheduler exception handling""" + + with ( + patch("fastdeploy.engine.common_engine.EngineWorkerQueue") as mock_worker_queue, + patch("fastdeploy.engine.common_engine.EngineCacheQueue") as mock_cache_queue, + patch("fastdeploy.engine.common_engine.asyncio.run") as mock_asyncio_run, + ): + + # Create mock queue instances + mock_queue_instance = MagicMock() + mock_queue_instance.get_server_port.return_value = 6808 + mock_worker_queue.return_value = mock_queue_instance + mock_cache_queue.return_value = mock_queue_instance + + engine = EngineService(self.cfg, start_queue=False, use_async_llm=False) + engine.running = True + + # Mock asyncio.run to raise an exception + test_exception = Exception("Test async loop error") + mock_asyncio_run.side_effect = test_exception + + # Mock the logger + engine.llm_logger = MagicMock() + + # Call the method + engine._run_insert_zmq_task_to_scheduler() + + # Verify the exception was caught and logged + mock_asyncio_run.assert_called_once() + # Verify it was called with a coroutine object + call_args = mock_asyncio_run.call_args[0] + self.assertEqual(len(call_args), 1) + import inspect + + self.assertTrue(inspect.iscoroutine(call_args[0])) + self.assertEqual(call_args[0].cr_code.co_name, "_insert_zmq_task_to_scheduler") + engine.llm_logger.error.assert_called_once_with("Async loop crashed: Test async loop error") + + def test_run_insert_zmq_task_to_scheduler_various_exceptions(self): + """Test _run_insert_zmq_task_to_scheduler with different types of exceptions""" + + with ( + patch("fastdeploy.engine.common_engine.EngineWorkerQueue") as mock_worker_queue, + patch("fastdeploy.engine.common_engine.EngineCacheQueue") as mock_cache_queue, + patch("fastdeploy.engine.common_engine.asyncio.run") as mock_asyncio_run, + ): + + # Create mock queue instances + mock_queue_instance = MagicMock() + mock_queue_instance.get_server_port.return_value = 6808 + mock_worker_queue.return_value = mock_queue_instance + mock_cache_queue.return_value = mock_queue_instance + + engine = EngineService(self.cfg, start_queue=False, use_async_llm=False) + engine.running = True + + # Test a couple of key exception types + test_exceptions = [ + RuntimeError("Runtime error"), + ValueError("Value error"), + ] + + for test_exception in test_exceptions: + # Reset the mock + mock_asyncio_run.reset_mock() + engine.llm_logger = MagicMock() + + # Mock asyncio.run to raise the test exception + mock_asyncio_run.side_effect = test_exception + + # Call the method + engine._run_insert_zmq_task_to_scheduler() + + # Verify the exception was caught and logged + mock_asyncio_run.assert_called_once() + # Verify it was called with a coroutine object + call_args = mock_asyncio_run.call_args[0] + self.assertEqual(len(call_args), 1) + import inspect + + self.assertTrue(inspect.iscoroutine(call_args[0])) + self.assertEqual(call_args[0].cr_code.co_name, "_insert_zmq_task_to_scheduler") + engine.llm_logger.error.assert_called_once_with(f"Async loop crashed: {test_exception}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/test_engine_client_send_task.py b/tests/entrypoints/test_engine_client_send_task.py new file mode 100644 index 00000000000..c431a2d5539 --- /dev/null +++ b/tests/entrypoints/test_engine_client_send_task.py @@ -0,0 +1,142 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from fastdeploy.entrypoints.engine_client import EngineClient + + +class 
TestEngineClientSendTask: + """测试 EngineClient._send_task 方法""" + + @pytest.fixture + def sample_task(self): + """示例任务数据""" + return {"request_id": "test_req_123", "prompt": "Hello, world!", "max_tokens": 100} + + @pytest.mark.asyncio + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + @patch("fastdeploy.entrypoints.engine_client.FMQFactory") + async def test_send_task_success_with_new_producer(self, mock_fmq_factory, mock_logger, sample_task): + """测试成功发送任务,需要创建新的 producer""" + # 创建真实的 EngineClient 实例 + client = EngineClient.__new__(EngineClient) + client.fmq_a2e_producer = None # 初始为 None + + # 模拟 FMQ producer + mock_producer = AsyncMock() + mock_fmq_factory.q_a2e_producer.return_value = mock_producer + + # 执行函数 + await client._send_task(sample_task) + + # 验证 producer 被创建 + mock_fmq_factory.q_a2e_producer.assert_called_once() + # 验证 producer.put 被调用 + mock_producer.put.assert_called_once_with(sample_task) + # 验证 fmq_a2e_producer 被设置 + assert client.fmq_a2e_producer == mock_producer + # 验证没有记录错误日志 + mock_logger.error.assert_not_called() + + @pytest.mark.asyncio + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + async def test_send_task_success_with_existing_producer(self, mock_logger, sample_task): + """测试成功发送任务,使用现有的 producer""" + # 创建真实的 EngineClient 实例 + client = EngineClient.__new__(EngineClient) + + # 模拟现有的 FMQ producer + mock_producer = AsyncMock() + client.fmq_a2e_producer = mock_producer + + # 执行函数 + await client._send_task(sample_task) + + # 验证 producer.put 被调用 + mock_producer.put.assert_called_once_with(sample_task) + # 验证没有记录错误日志 + mock_logger.error.assert_not_called() + + @pytest.mark.asyncio + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + @patch("fastdeploy.entrypoints.engine_client.FMQFactory") + async def test_send_task_producer_put_exception(self, mock_fmq_factory, mock_logger, sample_task): + """测试 producer.put 抛出异常的情况""" + # 创建真实的 EngineClient 实例 + client = EngineClient.__new__(EngineClient) + client.fmq_a2e_producer = None + + # 模拟 FMQ producer + mock_producer = AsyncMock() + mock_producer.put.side_effect = Exception("Connection failed") + mock_fmq_factory.q_a2e_producer.return_value = mock_producer + + # 执行函数 + await client._send_task(sample_task) + + # 验证 producer 被创建 + mock_fmq_factory.q_a2e_producer.assert_called_once() + # 验证 producer.put 被调用 + mock_producer.put.assert_called_once_with(sample_task) + # 验证异常被捕获并记录错误日志 + mock_logger.error.assert_called_once() + error_call_args = mock_logger.error.call_args[0][0] + assert "Failed to send task via FMQ: Connection failed" in error_call_args + + @pytest.mark.asyncio + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + @patch("fastdeploy.entrypoints.engine_client.FMQFactory") + async def test_send_task_get_producer_exception(self, mock_fmq_factory, mock_logger, sample_task): + """测试 _get_producer 抛出异常的情况""" + # 创建真实的 EngineClient 实例 + client = EngineClient.__new__(EngineClient) + client.fmq_a2e_producer = None + + # 模拟 FMQFactory 抛出异常 + mock_fmq_factory.q_a2e_producer.side_effect = Exception("Factory initialization failed") + + # 执行函数 + await client._send_task(sample_task) + + # 验证 FMQFactory 被调用 + mock_fmq_factory.q_a2e_producer.assert_called_once() + # 验证异常被捕获并记录错误日志 + mock_logger.error.assert_called_once() + error_call_args = mock_logger.error.call_args[0][0] + assert "Failed to send task via FMQ: Factory initialization failed" in error_call_args + + def test_get_producer_returns_existing(self): + """测试 _get_producer 方法返回现有 producer""" + # 创建真实的 
EngineClient 实例 + client = EngineClient.__new__(EngineClient) + + # 模拟现有的 FMQ producer + mock_producer = Mock() + client.fmq_a2e_producer = mock_producer + + # 执行函数 + result = client._get_producer() + + # 验证返回现有的 producer + assert result == mock_producer + + @patch("fastdeploy.entrypoints.engine_client.FMQFactory") + def test_get_producer_creates_new(self, mock_fmq_factory): + """测试 _get_producer 方法创建新的 producer""" + # 创建真实的 EngineClient 实例 + client = EngineClient.__new__(EngineClient) + client.fmq_a2e_producer = None + + # 模拟 FMQ producer + mock_producer = Mock() + mock_fmq_factory.q_a2e_producer.return_value = mock_producer + + # 执行函数 + result = client._get_producer() + + # 验证 FMQFactory 被调用 + mock_fmq_factory.q_a2e_producer.assert_called_once() + # 验证返回新的 producer + assert result == mock_producer + # 验证 fmq_a2e_producer 被设置 + assert client.fmq_a2e_producer == mock_producer
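
For reference, a minimal sketch (not part of the patch itself) of the q_a2e request path this series introduces, assuming FMQFactory.q_a2e_producer() / q_a2e_consumer() behave as shown in fmq.py above (async put()/get(), consumed messages exposing the original dict via msg.payload). In the real deployment the two sides run in separate processes (api_server / AsyncLLM on the producer side, EngineService on the consumer side); the single-process round trip below is only for illustration and assumes the configured FMQ endpoint allows it.

    import asyncio
    from typing import Optional

    from fastdeploy.inter_communicator import FMQFactory


    async def api_side_send(task: dict) -> None:
        # API-server / AsyncLLM side: lazily obtain the producer, then push the task.
        # Mirrors EngineClient._send_task / AsyncLLM.add_request in this patch.
        producer = FMQFactory.q_a2e_producer()
        await producer.put(task)


    async def engine_side_recv_once() -> Optional[dict]:
        # Engine side: the consumer replaces the old ZMQ PULL socket when
        # FD_ENABLE_INTERNAL_ADAPTER is disabled (see _insert_zmq_task_to_scheduler).
        consumer = FMQFactory.q_a2e_consumer()
        msg = await consumer.get()
        return None if msg is None else msg.payload


    async def main() -> None:
        # Illustrative only: producer and consumer normally live in different processes.
        task = {"request_id": "demo_req_0", "prompt": "hello"}
        await api_side_send(task)
        print(await engine_side_recv_once())


    if __name__ == "__main__":
        asyncio.run(main())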