27 changes: 22 additions & 5 deletions fastdeploy/engine/async_llm.py
@@ -35,7 +35,7 @@
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import FMQFactory, IPCSignal
from fastdeploy.inter_communicator.zmq_client import ZmqIpcClient
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import EngineError, llm_logger
@@ -299,6 +299,7 @@ def __init__(self, cfg, pid):
# Create high-performance async connection manager
self.connection_manager = None
self.request_client = None
self.fmq_a2e_producer = None

# Output processor uses data_processor for post-processing engine outputs
self.output_processor = AsyncOutputProcessor(self.data_processor)
@@ -307,6 +308,11 @@

main_process_metrics.set_cache_config_info(obj=self.cfg.cache_config)

def _get_producer(self):
if self.fmq_a2e_producer is None:
self.fmq_a2e_producer = FMQFactory.q_a2e_producer()
return self.fmq_a2e_producer

async def init_connections(self):
"""Initialize high-performance ZMQ connections"""
try:
@@ -439,10 +445,11 @@ async def add_request(
f"preprocess time cost {preprocess_cost_time}"
)

if not self.cfg.model_config.enable_mm:
self.request_client.send_json(request)
else:
self.request_client.send_pyobj(request)
try:
producer = self._get_producer()
await producer.put(request)
except Exception as e:
llm_logger.error(f"Failed to send task via FMQ: {e}")

except EngineError:
raise
@@ -603,6 +610,16 @@ async def shutdown(self):
self.request_client.close()
except Exception as e:
llm_logger.warning(f"Error closing request client: {e}")
# Close FMQ producer
if hasattr(self, "fmq_a2e_producer") and self.fmq_a2e_producer is not None:
try:
if hasattr(self.fmq_a2e_producer, "socket") and self.fmq_a2e_producer.socket is not None:
self.fmq_a2e_producer.socket.close()
llm_logger.info("FMQ producer socket closed successfully.")
except Exception as e:
llm_logger.error(f"Error closing fmq_producer: {e}")
finally:
self.fmq_a2e_producer = None

# Shutdown engine service process
try:
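For context on the new send path: the sketch below mirrors the lazy `_get_producer()` + `await producer.put(...)` pattern introduced above. It assumes only what the diff shows — `FMQFactory.q_a2e_producer()` returns an object with an async `put(data)` and a `.socket` attribute — and substitutes a stub producer so it runs standalone.

```python
# Minimal, self-contained sketch of the lazy producer pattern used in AsyncLLM.
# The real FMQFactory.q_a2e_producer() is assumed to return an object with an
# async put(data) coroutine and a .socket attribute; a stub stands in here so
# the example runs without FastDeploy installed.
import asyncio


class _StubProducer:
    """Hypothetical stand-in for the object returned by FMQFactory.q_a2e_producer()."""

    def __init__(self):
        self.socket = object()  # placeholder for the underlying ZMQ socket

    async def put(self, data):
        print(f"put: {data!r}")


class RequestSender:
    def __init__(self):
        self.fmq_a2e_producer = None  # created lazily, mirroring AsyncLLM

    def _get_producer(self):
        # Create the producer on first use so startup ordering is not an issue.
        if self.fmq_a2e_producer is None:
            self.fmq_a2e_producer = _StubProducer()  # real code: FMQFactory.q_a2e_producer()
        return self.fmq_a2e_producer

    async def add_request(self, request: dict):
        try:
            producer = self._get_producer()
            await producer.put(request)
        except Exception as e:
            print(f"Failed to send task via FMQ: {e}")


asyncio.run(RequestSender().add_request({"request_id": "req-0", "prompt": "hello"}))
```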
43 changes: 37 additions & 6 deletions fastdeploy/engine/common_engine.py
@@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import copy
import json
import multiprocessing
@@ -46,6 +47,7 @@
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
FMQFactory,
IPCSignal,
ZmqIpcServer,
ZmqTcpServer,
@@ -1027,20 +1029,28 @@ def start_zmq_service(self, api_server_pid=None):
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.fmq_a2e_consumer = None
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread = threading.Thread(
target=self._run_insert_zmq_task_to_scheduler, daemon=True
)
self.insert_task_to_scheduler_thread.start()

self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread.start()

def _insert_zmq_task_to_scheduler(self):
def _run_insert_zmq_task_to_scheduler(self):
try:
asyncio.run(self._insert_zmq_task_to_scheduler())
except Exception as e:
self.llm_logger.error(f"Async loop crashed: {e}")

async def _insert_zmq_task_to_scheduler(self):
tracing.trace_set_thread_info("Insert Task to Scheduler")
added_requests: Dict[str, int] = dict()
if envs.FD_ENABLE_INTERNAL_ADAPTER:
@@ -1050,10 +1060,22 @@ def _insert_zmq_task_to_scheduler(self):
while self.running:
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.model_config.enable_mm:
err, data = self.recv_request_server.receive_json_once(block)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not self.cfg.model_config.enable_mm:
err, data = self.recv_request_server.receive_json_once(block)
else:
err, data = self.recv_request_server.receive_pyobj_once(block)
else:
err, data = self.recv_request_server.receive_pyobj_once(block)
err = None
if self.fmq_a2e_consumer is None:
self.fmq_a2e_consumer = FMQFactory.q_a2e_consumer()
try:
msg = await self.fmq_a2e_consumer.get()
if msg is None:
continue
data = msg.payload
except Exception as e:
err = e
if err is not None:
# The message "Context was terminated" is normal when closing a ZMQ context
if "Context was terminated" in str(err):
@@ -1493,6 +1515,15 @@ def _exit_sub_services(self):
self.recv_request_server.close()
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
self.recv_control_cmd_server.close()
if hasattr(self, "fmq_a2e_consumer") and self.fmq_a2e_consumer is not None:
try:
if hasattr(self.fmq_a2e_consumer, "socket") and self.fmq_a2e_consumer.socket is not None:
self.fmq_a2e_consumer.socket.close()
llm_logger.info("FMQ consumer socket closed successfully.")
except Exception as e:
llm_logger.error(f"Error closing fmq_consumer: {e}")
finally:
self.fmq_a2e_consumer = None

# Moved from async_llm to common_engine
def _worker_processes_ready(self):
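The scheduler-insert loop is now a coroutine driven from a plain thread via `asyncio.run()`. A reduced, runnable sketch of that thread-plus-event-loop arrangement follows; the stub consumer stands in for the assumed `FMQFactory.q_a2e_consumer()` (async `get()` returning a message or `None`).

```python
# Sketch of running an async consumer loop inside a daemon thread via asyncio.run(),
# as _run_insert_zmq_task_to_scheduler does.
import asyncio
import threading


class _StubConsumer:
    """Hypothetical stand-in for FMQFactory.q_a2e_consumer(): async get() -> message or None."""

    def __init__(self, items):
        self._items = list(items)

    async def get(self):
        await asyncio.sleep(0)  # yield to the loop, like a real socket read would
        return self._items.pop(0) if self._items else None


class Scheduler:
    def __init__(self):
        self.running = True
        self.fmq_a2e_consumer = None

    async def _insert_task_to_scheduler(self):
        while self.running:
            if self.fmq_a2e_consumer is None:
                # real code: FMQFactory.q_a2e_consumer()
                self.fmq_a2e_consumer = _StubConsumer([{"request_id": "a"}, {"request_id": "b"}])
            msg = await self.fmq_a2e_consumer.get()
            if msg is None:
                self.running = False  # sketch only: stop when the stub runs dry
                continue
            print(f"insert into scheduler: {msg}")

    def _run(self):
        # The engine thread is synchronous, so the coroutine is driven with asyncio.run().
        try:
            asyncio.run(self._insert_task_to_scheduler())
        except Exception as e:
            print(f"Async loop crashed: {e}")


t = threading.Thread(target=Scheduler()._run, daemon=True)
t.start()
t.join()
```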
27 changes: 16 additions & 11 deletions fastdeploy/entrypoints/engine_client.py
@@ -33,6 +33,7 @@
from fastdeploy.eplb.utils import RedundantExpertWorkload
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
FMQFactory,
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
Expand All @@ -49,7 +50,6 @@
ParameterError,
StatefulSemaphore,
api_server_logger,
to_tensor,
)


@@ -82,6 +82,7 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.fmq_a2e_producer = None

if self.tensor_parallel_size <= self.max_chips_per_node:
self.is_master = True
@@ -348,7 +349,7 @@ async def add_requests(self, task):
request_id_idx = task.get("request_id")
parts = request_id_idx.rsplit("_", 1)
if len(parts) == 1:
self._send_task(task)
await self._send_task(task)
else:
request_id = parts[0]
index = int(parts[1])
@@ -357,21 +358,25 @@
for i in range(index * n, (index + 1) * n):
child_task = copy(task)
child_task["request_id"] = f"{request_id}_{i}"
self._send_task(child_task)
await self._send_task(child_task)
tracing.trace_slice_end(
tracing.TraceSpanName.PREPROCESSING, task.get("request_id").split("_")[0], thread_finish_flag=True
)
except Exception as e:
api_server_logger.error(f"zmq_client send task error: {e}, {str(traceback.format_exc())}")
api_server_logger.error(f"fmq send task error: {e}, {str(traceback.format_exc())}")
raise EngineError(str(e), error_code=400)

def _send_task(self, task):
if not self.enable_mm:
self.zmq_client.send_json(task)
else:
if envs.FD_ENABLE_E2W_TENSOR_CONVERT:
to_tensor([task])
self.zmq_client.send_pyobj(task)
def _get_producer(self):
if self.fmq_a2e_producer is None:
self.fmq_a2e_producer = FMQFactory.q_a2e_producer()
return self.fmq_a2e_producer

async def _send_task(self, task):
try:
producer = self._get_producer()
await producer.put(task)
except Exception as e:
api_server_logger.error(f"Failed to send task via FMQ: {e}")

def valid_parameters(self, data):
"""
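`add_requests` fans a request whose id ends in `_<index>` out into `n` child tasks before each is awaited through `_send_task`. The sketch below reproduces that fan-out in isolation; `task.get("n", 1)` is an assumption, since the line defining `n` sits outside the visible hunk, and the send is replaced with a print.

```python
# Standalone sketch of the fan-out in add_requests: a trailing "_<index>" in request_id
# selects which slice of n child requests to emit; ids without an index are sent as-is.
import asyncio
from copy import copy


async def fan_out(task: dict, send):
    request_id_idx = task.get("request_id")
    parts = request_id_idx.rsplit("_", 1)
    if len(parts) == 1:
        await send(task)
    else:
        request_id, index = parts[0], int(parts[1])
        n = task.get("n", 1)  # assumption: n is carried on the task dict
        for i in range(index * n, (index + 1) * n):
            child_task = copy(task)
            child_task["request_id"] = f"{request_id}_{i}"
            await send(child_task)


async def _demo():
    async def send(t):
        print("send:", t["request_id"])

    await fan_out({"request_id": "req_2", "n": 3}, send)  # emits req_6, req_7, req_8


asyncio.run(_demo())
```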
15 changes: 12 additions & 3 deletions fastdeploy/entrypoints/openai/api_server.py
@@ -24,7 +24,6 @@
from contextlib import asynccontextmanager

import uvicorn
import zmq
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -220,7 +219,6 @@ async def lifespan(app: FastAPI):
reward_handler = OpenAIServingReward(
engine_client, app.state.model_handler, fd_config, pid, args.ips, args.max_waiting_time, chat_template
)
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
app.state.engine_client = engine_client
app.state.chat_handler = chat_handler
@@ -234,7 +232,18 @@
# close zmq
try:
await engine_client.connection_manager.close()
engine_client.zmq_client.close()
if hasattr(engine_client, "fmq_a2e_producer") and engine_client.fmq_a2e_producer is not None:
try:
if (
hasattr(engine_client.fmq_a2e_producer, "socket")
and engine_client.fmq_a2e_producer.socket is not None
):
engine_client.fmq_a2e_producer.socket.close()
api_server_logger.info("FMQ producer socket closed successfully.")
except Exception as e:
api_server_logger.error(f"Error closing fmq_producer: {e}")
finally:
engine_client.fmq_a2e_producer = None
from prometheus_client import multiprocess

multiprocess.mark_process_dead(os.getpid())
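The guarded close-and-reset of the FMQ socket now appears in three places (AsyncLLM.shutdown, _exit_sub_services, and this lifespan hook). A possible shared helper is sketched below as a suggestion only — it is not part of the PR and assumes nothing beyond the attribute/`.socket` shape used above.

```python
# Possible shared helper for the repeated "close the FMQ endpoint's socket if present,
# log, and reset the attribute" pattern; logger is any object with info/error methods.
def close_fmq_endpoint(owner, attr: str, logger, label: str) -> None:
    endpoint = getattr(owner, attr, None)
    if endpoint is None:
        return
    try:
        sock = getattr(endpoint, "socket", None)
        if sock is not None:
            sock.close()
            logger.info(f"{label} socket closed successfully.")
    except Exception as e:
        logger.error(f"Error closing {label}: {e}")
    finally:
        setattr(owner, attr, None)


# usage in the lifespan hook, under the same assumptions:
# close_fmq_endpoint(engine_client, "fmq_a2e_producer", api_server_logger, "FMQ producer")
```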
4 changes: 4 additions & 0 deletions fastdeploy/inter_communicator/__init__.py
@@ -16,6 +16,8 @@

from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .fmq import FMQ
from .fmq_factory import FMQFactory
from .ipc_signal import IPCSignal, shared_memory_exists
from .ipc_signal_const import (
ExistTaskStatus,
@@ -40,4 +42,6 @@
"ModelWeightsStatus",
"KVCacheStatus",
"RearrangeExpertStatus",
"FMQ",
"FMQFactory",
]
37 changes: 24 additions & 13 deletions fastdeploy/inter_communicator/fmq.py
@@ -28,6 +28,8 @@
import zmq.asyncio

from fastdeploy import envs
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.stats import ZMQMetricsStats
from fastdeploy.utils import fmq_logger

# ==========================
Expand Down Expand Up @@ -232,19 +234,28 @@ async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
if self.role != Role.PRODUCER:
raise PermissionError("Only producers can send messages.")

desc = None
payload = data

if isinstance(data, bytes) and len(data) >= shm_threshold:
desc = Descriptor.create(data)
payload = None

msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc)
raw = msg.serialize()

async with self.lock:
await self.socket.send(raw, copy=self.copy)
self._msg_id += 1
_zmq_metrics_stats = ZMQMetricsStats()
try:
desc = None
payload = data

if isinstance(data, bytes) and len(data) >= shm_threshold:
desc = Descriptor.create(data)
payload = None

msg = Message(msg_id=self._msg_id, payload=payload, descriptor=desc, timestamp=time.perf_counter())
raw = msg.serialize()
_zmq_metrics_stats.msg_bytes_send_total += len(raw)

async with self.lock:
await self.socket.send(raw, copy=self.copy)
self._msg_id += 1
except Exception as e:
_zmq_metrics_stats.msg_send_failed_total += 1
fmq_logger.error(f"Failed to send message: {e}")
finally:
_zmq_metrics_stats.msg_send_total += 1
main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.endpoint.address)

async def get(self, timeout: int = None) -> Optional[Message]:
# Receive data from queue
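`put()` now inlines small payloads, ships byte payloads at or above `shm_threshold` as a shared-memory descriptor only, and counts bytes and failures for the metrics hook. The illustration below strips that down to the branching itself: a simplified `Descriptor` built on `multiprocessing.shared_memory` (POSIX semantics assumed, not the PR's actual Descriptor) and a plain dict standing in for `ZMQMetricsStats`.

```python
# Illustration of the size-based branching in FMQ.put(): large byte payloads are copied
# into shared memory and only a small descriptor travels over the wire; everything else
# is inlined. Counters stand in for ZMQMetricsStats / main_process_metrics.
import pickle
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Optional


@dataclass
class Descriptor:
    shm_name: str
    size: int

    @classmethod
    def create(cls, data: bytes) -> "Descriptor":
        # Copy the payload into a shared-memory segment; the consumer reopens it by name.
        shm = shared_memory.SharedMemory(create=True, size=len(data))
        shm.buf[: len(data)] = data
        shm.close()
        return cls(shm_name=shm.name, size=len(data))


def serialize(msg_id: int, payload: Any, descriptor: Optional[Descriptor]) -> bytes:
    return pickle.dumps({"msg_id": msg_id, "payload": payload, "descriptor": descriptor})


stats = {"msg_send_total": 0, "msg_send_failed_total": 0, "msg_bytes_send_total": 0}


def put(data: Any, shm_threshold: int = 1024 * 1024) -> bytes:
    try:
        desc, payload = None, data
        if isinstance(data, bytes) and len(data) >= shm_threshold:
            desc, payload = Descriptor.create(data), None  # large bytes: descriptor only
        raw = serialize(0, payload, desc)
        stats["msg_bytes_send_total"] += len(raw)
        return raw  # real code: await self.socket.send(raw, copy=self.copy)
    except Exception:
        stats["msg_send_failed_total"] += 1
        raise
    finally:
        stats["msg_send_total"] += 1


small = put({"request_id": "r1"})
big = put(b"\x00" * (2 * 1024 * 1024))
desc = pickle.loads(big)["descriptor"]
print(len(small), len(big), stats, desc.size)

# Cleanup for the sketch; the real consumer would do this after reading the segment.
shm = shared_memory.SharedMemory(name=desc.shm_name)
shm.close()
shm.unlink()
```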