lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py (16 changes: 11 additions & 5 deletions)
@@ -52,27 +52,33 @@ def get_total_slots():
 
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _, _ = step_context.kv_caches[0][1].shape
-        device = step_context.block_offsets.device
 
         is_unpaged_prefill = False
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-        q_start_loc = torch.cat((torch.tensor([0], device=device), step_context.q_seqlens.cumsum(0))).int()
+        q_start_loc = step_context.q_start_loc
+        cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
 
         q_seqlens = step_context.q_seqlens.int()
         kv_seqlens = step_context.kv_seqlens.int()
-        max_q_seq_len = torch.max(q_seqlens).item()
-        max_kv_seq_len = torch.max(kv_seqlens).item()
 
         if step_context.is_decoding:
+            # max_q_seq_len and max_kv_seq_len are not used in the decoding stage
+            max_q_seq_len = -1
+            max_kv_seq_len = -1
+
             # collect kv_start_indices without using a for-loop,
             # (fill kv-cache for just ONE token during the decoding phase)
             idx = (step_context.kv_seqlens - 1) % block_size
             b_num = (step_context.kv_seqlens - 1) // block_size
             last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
             kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
         else:
+            max_q_seq_len = torch.max(q_seqlens).cpu().item()
+            max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
             for i in range(step_context.q_start_loc.size(0)):
                 q_seq_len = int(step_context.q_seqlens[i])
                 kv_seq_len = int(step_context.kv_seqlens[i])
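
A note on the new cu_seqlens: the sketch below (not part of the PR) shows what it evaluates to, assuming step_context.q_start_loc holds the exclusive cumulative sum of q_seqlens, one start offset per sequence, which is what the torch.cat above implies.

    import torch

    # toy batch: three prefill sequences of lengths 3, 5 and 2
    q_seqlens = torch.tensor([3, 5, 2])

    # per-sequence start offsets, as step_context.q_start_loc is assumed
    # to provide them: an exclusive cumsum, one entry per sequence
    q_start_loc = torch.cat((torch.tensor([0]), q_seqlens.cumsum(0)[:-1]))

    # appending the batch total yields the batch+1 boundary vector
    cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
    print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

Reusing the precomputed q_start_loc also avoids building a fresh tensor on the device every step, which is presumably why the removed torch.cat/cumsum line went away.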
@@ -88,7 +94,7 @@ def get_total_slots():
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets.int(),
-            q_start_loc=q_start_loc,
+            q_start_loc=cu_seqlens,
             q_seqlens=q_seqlens,
             kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
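The decoding branch above fills the kv-cache for exactly one new token per sequence, so only the slot of each sequence's last token is needed. The sketch below is not part of the PR; it replays the same index arithmetic on toy tensors and checks it against the per-sequence loop it avoids, assuming block_offsets[b] lists the physical block ids of sequence b in logical order.

    import torch

    block_size = 4
    # toy paged-KV layout: physical block ids per sequence, in logical order
    block_offsets = torch.tensor([[7, 2, 5], [3, 9, 8]])
    kv_seqlens = torch.tensor([6, 9])  # lengths including the token being decoded

    # vectorized form, as in the diff: locate the slot of the last token
    idx = (kv_seqlens - 1) % block_size     # offset inside its block
    b_num = (kv_seqlens - 1) // block_size  # which logical block holds it
    last_block = block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
    kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))

    # reference: the per-sequence loop the vectorized form replaces
    expected = torch.tensor(
        [[int(block_offsets[i, (n - 1) // block_size]) * block_size + (n - 1) % block_size]
         for i, n in enumerate(kv_seqlens.tolist())])
    assert torch.equal(kv_start_indices, expected)
    print(kv_start_indices.view(-1).tolist())  # [9, 32]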
lmdeploy/pytorch/engine/executor/ray_executor.py (11 changes: 4 additions & 7 deletions)
@@ -561,13 +561,14 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
     def _init_distributed_environment_by_device(self, device_str: str):
         """Init distributed environment."""
         driver_ip = _get_master_addr()
-        if device_str in ['cuda', 'maca']:
+        if device_str == 'cuda':
             self.workers = self._sort_workers(driver_ip, self.workers)
 
         elif device_str == 'ascend':
             self._init_ascend_distributed_environment(driver_ip)
-        elif device_str == 'camb':
-            self._init_camb_distributed_environment(driver_ip)
+        elif device_str in ['camb', 'maca']:
+            self.workers = self._sort_workers(driver_ip, self.workers)
+            ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
         else:
             raise ValueError(f'Unsupported device type: {device_str}')
 
@@ -590,10 +591,6 @@ def _init_ascend_distributed_environment(self, driver_ip):
         else:
             self.workers = self._sort_workers(driver_ip, self.workers)
 
-    def _init_camb_distributed_environment(self, driver_ip):
-        self.workers = self._sort_workers(driver_ip, self.workers)
-        ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
-
     """ PD Disaggregation API Begin """
 
     def p2p_initialize(self, init_request: DistServeInitRequest):
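For context on the consolidated camb/maca branch: a minimal, self-contained sketch of the set_device fan-out pattern, using a stub actor in place of lmdeploy's real ray worker (Worker and its body are illustrative assumptions, not the project's API; only the ray calls mirror the diff).

    import ray


    @ray.remote
    class Worker:
        """Stub standing in for the real ray worker actor."""

        def set_device(self, device_idx: int) -> int:
            # a real worker would bind itself to accelerator `device_idx` here
            self.device_idx = device_idx
            return device_idx


    if __name__ == '__main__':
        ray.init()
        workers = [Worker.remote() for _ in range(4)]
        # same pattern as the diff: one set_device call per worker, gathered once
        assigned = ray.get([w.set_device.remote(idx) for idx, w in enumerate(workers)])
        print(assigned)  # [0, 1, 2, 3]
        ray.shutdown()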