2 changes: 0 additions & 2 deletions fastdeploy/engine/args_utils.py
@@ -486,8 +486,6 @@ def __post_init__(self):
         self.tokenizer = self.model
         if self.splitwise_role == "decode":
             self.enable_prefix_caching = False
-        if self.speculative_config is not None:
-            self.enable_prefix_caching = False
         if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
             self.enable_prefix_caching = False
         # if self.dynamic_load_weight:
37 changes: 31 additions & 6 deletions fastdeploy/engine/common_engine.py
@@ -718,6 +718,10 @@ def _fetch_request():
 
             self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
             if self.cfg.scheduler_config.splitwise_role != "mixed":
+                if self.cfg.scheduler_config.splitwise_role == "prefill":
+                    for task in tasks:
+                        # start async preprocess
+                        self.resource_manager.apply_async_preprocess(task)
                 need_delete_tasks = []
                 if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
                     for task in tasks:
@@ -770,15 +774,36 @@ def _fetch_request():
                 self.split_connector.send_cache_info_to_messager(tasks, 0)
                 # ensure cache tasks has sent to cache_messager
                 need_check_req_ids = [task.request_id for task in tasks]
+                finished_ids, delete_tasks_list = [], []
                 while need_check_req_ids:
-                    req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
-                    self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
-                    if req_ids:
-                        for req_id in req_ids:
-                            assert req_id in need_check_req_ids
-                            need_check_req_ids.remove(req_id)
+                    finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
+                    self.llm_logger.info(f"get_finished_add_cache_task_req: {finished_ids}")
+                    if finished_ids:
+                        for task in tasks:
+                            result = self.resource_manager.waiting_async_process(task)
+                            if result is None:
+                                self.scheduler.put_results(
+                                    [
+                                        RequestOutput(
+                                            request_id=task.request_id,
+                                            finished=True,
+                                            error_code=task.error_code,
+                                            error_msg=task.error_message,
+                                        )
+                                    ]
+                                )
+                                delete_tasks_list.append(task)
+                            elif result is False:
+                                if task.request_id in finished_ids:
+                                    need_check_req_ids.remove(task.request_id)
+                                    finished_ids.remove(task.request_id)
                     else:
                         time.sleep(0.001)
+
+                for tmp_task in delete_tasks_list:
+                    tasks.remove(tmp_task)
+                    # release resource in P
+                    self.resource_manager.pre_recycle_resource(tmp_task.request_id)
             # Fetch requests and add them to the scheduling queue
             if tasks:
                 for task in tasks:
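The loop added above follows a tri-state convention: waiting_async_process returns None when async preprocessing failed (the engine then publishes an error RequestOutput and recycles the request's resources), False once all futures have completed cleanly, and True while work is still in flight. The sketch below is a minimal, runnable model of that polling pattern; it uses bare concurrent.futures objects as stand-ins for FastDeploy's request and resource-manager types, and signals failure through future exceptions rather than the request's error_message field.

import concurrent.futures
import time
from typing import List, Optional

def wait_state(futures: List[concurrent.futures.Future]) -> Optional[bool]:
    # Tri-state check mirroring waiting_async_process:
    # None -> a future failed, True -> still pending, False -> all done.
    for f in list(futures):
        if not f.done():
            return True
        if f.exception() is not None:
            return None
        futures.remove(f)  # completed cleanly; stop tracking it
    return False

pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
futures = [pool.submit(time.sleep, 0.01)]
while (state := wait_state(futures)) is True:
    time.sleep(0.001)  # same 1 ms back-off as the engine loop above
print("preprocess failed" if state is None else "preprocess done")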
16 changes: 16 additions & 0 deletions fastdeploy/engine/request.py
@@ -183,6 +183,22 @@ def __init__(
         self.error_message = None
         self.error_code = None
+
+    def __getstate__(self):
+        """
+        Custom getstate method for pickle support.
+        Handles unpicklable attributes by filtering them from __dict__.
+        """
+        # Create a filtered dictionary without problematic attributes
+        filtered_dict = {}
+        for key, value in self.__dict__.items():
+            # Skip attributes that are known to contain unpicklable objects
+            if key == "async_process_futures":
+                filtered_dict[key] = []
+            else:
+                filtered_dict[key] = value
+
+        return filtered_dict
 
     @classmethod
     def from_dict(cls, d: dict):
         data_processor_logger.debug(f"{d}")
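Request objects are pickled when they cross process boundaries (for example through the engine worker queue), and the concurrent.futures.Future objects appended by apply_async_preprocess hold thread locks that cannot be pickled; __getstate__ therefore swaps them for an empty list before serialization. A small runnable illustration of the same pattern, with a stand-in class in place of Request:

import concurrent.futures
import pickle

class Job:  # stand-in for fastdeploy.engine.request.Request
    def __init__(self, request_id: str):
        self.request_id = request_id
        self.async_process_futures = []

    def __getstate__(self):
        state = dict(self.__dict__)
        # Futures wrap thread locks, which pickle rejects; drop them.
        state["async_process_futures"] = []
        return state

pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
job = Job("req-0")
job.async_process_futures.append(pool.submit(sum, [1, 2, 3]))

clone = pickle.loads(pickle.dumps(job))  # succeeds: futures were filtered out
assert clone.request_id == "req-0"
assert clone.async_process_futures == []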
8 changes: 4 additions & 4 deletions fastdeploy/engine/sched/resource_manager_v1.py
@@ -653,7 +653,7 @@ def _allocate_decode_and_extend():
                 ):
                     break
                 if request.status == RequestStatus.WAITING:
-                    result = self._waiting_async_process(request)
+                    result = self.waiting_async_process(request)
                     if result is None:
                         error_reqs.append((request.request_id, request.error_message))
                         self.waiting.popleft()
@@ -761,7 +761,7 @@ def _allocate_decode_and_extend():
 
         return scheduled_reqs, error_reqs
 
-    def _waiting_async_process(self, request: Request) -> None:
+    def waiting_async_process(self, request: Request) -> None:
         """
         Check if async preprocessing is complete for a request.
         Args:
@@ -780,7 +780,7 @@ def _waiting_async_process(self, request: Request) -> None:
         request.async_process_futures = []
         return False
 
-    def _apply_async_preprocess(self, request: Request) -> None:
+    def apply_async_preprocess(self, request: Request) -> None:
         request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
 
     def _has_features_info(self, task):
@@ -903,7 +903,7 @@ def get_prefix_cached_blocks(self, request: Request):
 
     def add_request(self, request: Request) -> None:
         with self.lock:
-            self._apply_async_preprocess(request)
+            self.apply_async_preprocess(request)
             self.waiting.append(request)
             self.requests[request.request_id] = request
 
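The two renamed methods form a submit/poll pair, and the leading underscores are dropped because common_engine.py now calls them from outside the class. The sketch below models that contract end to end; the Manager and Req classes and the _download_features body are illustrative stand-ins rather than FastDeploy internals, but the tri-state return values match the unit tests further down:

import concurrent.futures
import time

class Req:
    def __init__(self, request_id):
        self.request_id = request_id
        self.error_message = None
        self.async_process_futures = []

class Manager:
    def __init__(self):
        self.async_preprocess_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)

    def _download_features(self, req):
        time.sleep(0.01)  # stand-in for fetching multimodal features
        # on failure the real code records the reason on the request,
        # e.g. req.error_message = "Download failed"

    def apply_async_preprocess(self, req):
        req.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, req))

    def waiting_async_process(self, req):
        if req.error_message is not None:
            return None   # preprocessing failed
        if any(not f.done() for f in req.async_process_futures):
            return True   # still running, check again later
        req.async_process_futures = []
        return False      # finished; safe to schedule

mgr, req = Manager(), Req("req-0")
mgr.apply_async_preprocess(req)
while mgr.waiting_async_process(req) is True:
    time.sleep(0.001)
assert req.async_process_futures == []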
4 changes: 1 addition & 3 deletions fastdeploy/spec_decode/mtp.py
@@ -176,9 +176,7 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
         if kv_cache_quant_type == "block_wise_fp8":
             kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
-        ):
+        if not profile and self.scheduler_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
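The narrowed condition above means the draft model's extra KV caches are taken over from the externally managed pool only in disaggregated (splitwise) deployments; enabling prefix caching by itself no longer selects that path. A tiny runnable sketch of the resulting decision, with an illustrative helper name:

def mtp_cache_source(profile: bool, splitwise_role: str) -> str:
    # Mirrors the condition after this change; "external" stands for the
    # caches registered with the cache messager, "local" for fresh tensors.
    if not profile and splitwise_role != "mixed":
        return "external"
    return "local"

assert mtp_cache_source(False, "prefill") == "external"
assert mtp_cache_source(False, "mixed") == "local"   # prefix caching alone no longer qualifies
assert mtp_cache_source(True, "decode") == "local"   # profiling always allocates locally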
10 changes: 5 additions & 5 deletions tests/v1/test_resource_manager_v1.py
@@ -54,7 +54,7 @@ def setUp(self):
 
     def test_waiting_async_process_no_futures(self):
         """Test when there are no async process futures"""
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertFalse(result)
 
     def test_waiting_async_process_future_done_no_error(self):
@@ -63,7 +63,7 @@ def test_waiting_async_process_future_done_no_error(self):
         future.set_result(True)
         self.request.async_process_futures = [future]
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertFalse(result)
         self.assertEqual(len(self.request.async_process_futures), 0)
 
@@ -74,23 +74,23 @@ def test_waiting_async_process_future_done_with_error(self):
         self.request.async_process_futures = [future]
         self.request.error_message = "Download failed"
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertIsNone(result)
 
     def test_waiting_async_process_future_not_done(self):
         """Test when future is not done"""
         future = concurrent.futures.Future()
         self.request.async_process_futures = [future]
 
-        result = self.manager._waiting_async_process(self.request)
+        result = self.manager.waiting_async_process(self.request)
         self.assertTrue(result)
         self.assertEqual(len(self.request.async_process_futures), 1)
 
     def test_apply_async_preprocess(self):
         """Test applying async preprocess"""
         with patch.object(self.manager.async_preprocess_pool, "submit") as mock_submit:
             mock_submit.return_value = "mock_future"
-            self.manager._apply_async_preprocess(self.request)
+            self.manager.apply_async_preprocess(self.request)
 
             mock_submit.assert_called_once_with(self.manager._download_features, self.request)
             self.assertEqual(len(self.request.async_process_futures), 1)